Merge branch 'main' into feat/refactor_generation_backend

This commit is contained in:
Sergey Borisov 2023-08-10 04:32:16 +03:00
commit ade78b9591
43 changed files with 1970 additions and 407 deletions

View File

@ -306,13 +306,30 @@ InvokeAI. The second will prepare the 2.3 directory for use with 3.0.
You may now launch the WebUI in the usual way, by selecting option [1]
from the launcher script
#### Migration Caveats
#### Migrating Images
The migration script will migrate your invokeai settings and models,
including textual inversion models, LoRAs and merges that you may have
installed previously. However it does **not** migrate the generated
images stored in your 2.3-format outputs directory. You will need to
manually import selected images into the 3.0 gallery via drag-and-drop.
images stored in your 2.3-format outputs directory. To do this, you
need to run an additional step:
1. From a working InvokeAI 3.0 root directory, start the launcher and
enter menu option [8] to open the "developer's console".
2. At the developer's console command line, type the command:
```bash
invokeai-import-images
```
3. This will lead you through the process of confirming the desired
source and destination for the imported images. The images will
appear in the gallery board of your choice, and contain the
original prompt, model name, and other parameters used to generate
the image.
(Many kudos to **techjedi** for contributing this script.)
## Hardware Requirements

View File

@ -264,7 +264,7 @@ experimental versions later.
you can create several levels of subfolders and drop your models into
whichever ones you want.
- ***Autoimport FolderLICENSE***
- ***LICENSE***
At the bottom of the screen you will see a checkbox for accepting
the CreativeML Responsible AI Licenses. You need to accept the license

View File

@ -55,7 +55,7 @@ logger = InvokeAILogger.getLogger()
class ApiDependencies:
"""Contains and initializes all dependencies for the API"""
invoker: Optional[Invoker] = None
invoker: Invoker
@staticmethod
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
@ -68,8 +68,9 @@ class ApiDependencies:
output_folder = config.output_path
# TODO: build a file/path manager?
db_location = config.db_path
db_location.parent.mkdir(parents=True, exist_ok=True)
db_path = config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
db_location = str(db_path)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"

View File

@ -1,22 +1,20 @@
import io
from typing import Optional
from PIL import Image
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, Field
from pydantic import BaseModel
from invokeai.app.invocations.metadata import ImageMetadata
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.services.models.image_record import (
ImageDTO,
ImageRecordChanges,
ImageUrlsDTO,
)
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@ -152,8 +150,9 @@ async def get_image_metadata(
raise HTTPException(status_code=404)
@images_router.get(
@images_router.api_route(
"/i/{image_name}/full",
methods=["GET", "HEAD"],
operation_id="get_image_full",
response_class=Response,
responses={

View File

@ -4,6 +4,7 @@ from typing import Literal, Optional
import cv2
import numpy
import cv2
from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import Field
from pathlib import Path
@ -502,7 +503,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
image = context.services.images.get_pil_image(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
image_arr = image_arr * (self.max - self.min) + self.max
image_arr = image_arr * (self.max - self.min) + self.min
lerp_image = Image.fromarray(numpy.uint8(image_arr))
@ -653,6 +654,7 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
height=image_dto.height,
)
class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
"""Applies an edge mask to an image"""
@ -702,6 +704,7 @@ class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
height=image_dto.height,
)
class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["color_correct"] = "color_correct"
@ -817,3 +820,142 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
height=image_dto.height,
)
class ImageHueAdjustmentInvocation(BaseInvocation):
"""Adjusts the Hue of an image."""
# fmt: off
type: Literal["img_hue_adjust"] = "img_hue_adjust"
# Inputs
image: ImageField = Field(default=None, description="The image to adjust")
hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
# Convert image to HSV color space
hsv_image = numpy.array(pil_image.convert("HSV"))
# Convert hue from 0..360 to 0..256
hue = int(256 * ((self.hue % 360) / 360))
# Increment each hue and wrap around at 255
hsv_image[:, :, 0] = (hsv_image[:, :, 0] + hue) % 256
# Convert back to PIL format and to original color mode
pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA")
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
),
width=image_dto.width,
height=image_dto.height,
)
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
"""Adjusts the Luminosity (Value) of an image."""
# fmt: off
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
# Inputs
image: ImageField = Field(default=None, description="The image to adjust")
luminosity: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
# Convert PIL image to OpenCV format (numpy array), note color channel
# ordering is changed from RGB to BGR
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
# Convert image to HSV color space
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
# Adjust the luminosity (value)
hsv_image[:, :, 2] = numpy.clip(hsv_image[:, :, 2] * self.luminosity, 0, 255)
# Convert image back to BGR color space
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
# Convert back to PIL format and to original color mode
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
),
width=image_dto.width,
height=image_dto.height,
)
class ImageSaturationAdjustmentInvocation(BaseInvocation):
"""Adjusts the Saturation of an image."""
# fmt: off
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
# Inputs
image: ImageField = Field(default=None, description="The image to adjust")
saturation: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
# Convert PIL image to OpenCV format (numpy array), note color channel
# ordering is changed from RGB to BGR
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
# Convert image to HSV color space
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
# Adjust the saturation
hsv_image[:, :, 1] = numpy.clip(hsv_image[:, :, 1] * self.saturation, 0, 255)
# Convert image back to BGR color space
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
# Convert back to PIL format and to original color mode
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -5,15 +5,26 @@ from typing import List, Literal, Optional, Union
import einops
import torch
from diffusers import ControlNetModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
@ -239,7 +250,6 @@ class TextToLatentsInvocation(BaseInvocation):
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
precision="float16" if unet.dtype == torch.float16 else "float32",
)
def prep_control_data(

View File

@ -24,11 +24,10 @@ InvokeAI:
sequential_guidance: false
precision: float16
max_cache_size: 6
max_vram_cache_size: 2.7
max_vram_cache_size: 0.5
always_use_cpu: false
free_gpu_mem: false
Features:
restore: true
esrgan: true
patchmatch: true
internet_available: true
@ -165,7 +164,7 @@ import pydoc
import os
import sys
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig
from omegaconf import OmegaConf, DictConfig, ListConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
@ -173,6 +172,7 @@ from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_ty
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_MAX_VRAM = 0.5
class InvokeAISettings(BaseSettings):
@ -189,7 +189,12 @@ class InvokeAISettings(BaseSettings):
opt = parser.parse_args(argv)
for name in self.__fields__:
if name not in self._excluded():
setattr(self, name, getattr(opt, name))
value = getattr(opt, name)
if isinstance(value, ListConfig):
value = list(value)
elif isinstance(value, DictConfig):
value = dict(value)
setattr(self, name, value)
def to_yaml(self) -> str:
"""
@ -282,14 +287,10 @@ class InvokeAISettings(BaseSettings):
return [
"type",
"initconf",
"gpu_mem_reserved",
"max_loaded_models",
"version",
"from_file",
"model",
"restore",
"root",
"nsfw_checker",
]
class Config:
@ -388,15 +389,11 @@ class InvokeAIAppConfig(InvokeAISettings):
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='DEPRECATED')
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
gpu_mem_reserved : float = Field(default=2.75, ge=0, description="DEPRECATED: use max_vram_cache_size. Amount of VRAM reserved for model storage", category='DEPRECATED')
nsfw_checker : bool = Field(default=True, description="DEPRECATED: use Web settings to enable/disable", category='DEPRECATED')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
@ -414,9 +411,7 @@ class InvokeAIAppConfig(InvokeAISettings):
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert')
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
@ -426,6 +421,9 @@ class InvokeAIAppConfig(InvokeAISettings):
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
# fmt: on
class Config:
validate_assignment = True
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
"""
Update settings with contents of init file, environment, and

View File

@ -3,9 +3,10 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path
from pydantic import Field
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
from typing import Literal, Optional, Union, Callable, List, Tuple, TYPE_CHECKING
from types import ModuleType
from invokeai.backend.model_management import (
@ -193,7 +194,7 @@ class ModelManagerServiceBase(ABC):
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main, ModelType.Vae],
model_type: Literal[ModelType.Main, ModelType.Vae],
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
@ -292,7 +293,7 @@ class ModelManagerService(ModelManagerServiceBase):
def __init__(
self,
config: InvokeAIAppConfig,
logger: ModuleType,
logger: Logger,
):
"""
Initialize with the path to the models.yaml config file.
@ -396,7 +397,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_type,
)
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
"""
@ -416,7 +417,7 @@ class ModelManagerService(ModelManagerServiceBase):
"""
return self.mgr.list_models(base_model, model_type)
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
"""
Return information about the model using the same format as list_models()
"""
@ -429,7 +430,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> None:
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
@ -478,7 +479,7 @@ class ModelManagerService(ModelManagerServiceBase):
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main, ModelType.Vae],
model_type: Literal[ModelType.Main, ModelType.Vae],
convert_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
@ -573,9 +574,9 @@ class ModelManagerService(ModelManagerServiceBase):
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
force: bool = False,
merge_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
@ -633,8 +634,8 @@ class ModelManagerService(ModelManagerServiceBase):
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str = None,
new_base: BaseModelType = None,
new_name: Optional[str] = None,
new_base: Optional[BaseModelType] = None,
):
"""
Rename the indicated model. Can provide a new name and/or a new base.

View File

@ -10,12 +10,15 @@ import sys
import argparse
import io
import os
import psutil
import shutil
import textwrap
import torch
import traceback
import yaml
import warnings
from argparse import Namespace
from enum import Enum
from pathlib import Path
from shutil import get_terminal_size
from typing import get_type_hints
@ -44,6 +47,8 @@ from invokeai.app.services.config import (
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
# TO DO - Move all the frontend code into invokeai.frontend.install
from invokeai.frontend.install.widgets import (
SingleSelectColumns,
CenteredButtonPress,
@ -53,6 +58,7 @@ from invokeai.frontend.install.widgets import (
CyclingForm,
MIN_COLS,
MIN_LINES,
WindowTooSmallException,
)
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.install.model_install_backend import (
@ -61,6 +67,7 @@ from invokeai.backend.install.model_install_backend import (
ModelInstall,
)
from invokeai.backend.model_management.model_probe import ModelType, BaseModelType
from pydantic.error_wrappers import ValidationError
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
@ -76,6 +83,13 @@ Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path
PRECISION_CHOICES = ["auto", "float16", "float32"]
GB = 1073741824 # GB in bytes
HAS_CUDA = torch.cuda.is_available()
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
MAX_VRAM /= GB
MAX_RAM = psutil.virtual_memory().total / GB
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# This is the InvokeAI initialization file, which contains command-line default values.
@ -86,6 +100,12 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
logger = InvokeAILogger.getLogger()
class DummyWidgetValue(Enum):
zero = 0
true = True
false = False
# --------------------------------------------
def postscript(errors: None):
if not any(errors):
@ -378,13 +398,35 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
)
self.max_cache_size = self.add_widget_intelligent(
IntTitleSlider,
name="Size of the RAM cache used for fast model switching (GB)",
name="RAM cache size (GB). Make this at least large enough to hold a single full model.",
value=old_opts.max_cache_size,
out_of=20,
out_of=MAX_RAM,
lowest=3,
begin_entry_at=6,
scroll_exit=True,
)
if HAS_CUDA:
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.max_vram_cache_size = self.add_widget_intelligent(
npyscreen.Slider,
value=old_opts.max_vram_cache_size,
out_of=round(MAX_VRAM * 2) / 2,
lowest=0.0,
relx=8,
step=0.25,
scroll_exit=True,
)
else:
self.max_vram_cache_size = DummyWidgetValue.zero
self.nextrely += 1
self.outdir = self.add_widget_intelligent(
FileBox,
@ -401,7 +443,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
self.autoimport_dirs = {}
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
FileBox,
name=f"Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
value=str(config.root_path / config.autoimport_dir),
select_dir=True,
must_exist=False,
@ -476,6 +518,7 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
"outdir",
"free_gpu_mem",
"max_cache_size",
"max_vram_cache_size",
"xformers_enabled",
"always_use_cpu",
]:
@ -592,13 +635,13 @@ def maybe_create_models_yaml(root: Path):
# -------------------------------------
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
# parse_args() will read from init file if present
invokeai_opts = default_startup_options(initfile)
invokeai_opts.root = program_opts.root
# The third argument is needed in the Windows 11 environment to
# launch a console window running this program.
set_min_terminal_size(MIN_COLS, MIN_LINES)
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
raise WindowTooSmallException(
"Could not increase terminal size. Try running again with a larger window or smaller font size."
)
# the install-models application spawns a subprocess to install
# models, and will crash unless this is set before running.
@ -654,10 +697,13 @@ def migrate_init_file(legacy_format: Path):
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
new = InvokeAIAppConfig.get_config()
fields = list(get_type_hints(InvokeAIAppConfig).keys())
fields = [x for x, y in InvokeAIAppConfig.__fields__.items() if y.field_info.extra.get("category") != "DEPRECATED"]
for attr in fields:
if hasattr(old, attr):
setattr(new, attr, getattr(old, attr))
try:
setattr(new, attr, getattr(old, attr))
except ValidationError as e:
print(f"* Ignoring incompatible value for field {attr}:\n {str(e)}")
# a few places where the field names have changed and we have to
# manually add in the new names/values
@ -777,6 +823,7 @@ def main():
models_to_download = default_user_selections(opt)
new_init_file = config.root_path / "invokeai.yaml"
if opt.yes_to_all:
write_default_options(opt, new_init_file)
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
@ -802,6 +849,8 @@ def main():
postscript(errors=errors)
if not opt.yes_to_all:
input("Press any key to continue...")
except WindowTooSmallException as e:
logger.error(str(e))
except KeyboardInterrupt:
print("\nGoodbye! Come back soon.")

View File

@ -101,9 +101,9 @@ class ModelInstall(object):
def __init__(
self,
config: InvokeAIAppConfig,
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
model_manager: ModelManager = None,
access_token: str = None,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
model_manager: Optional[ModelManager] = None,
access_token: Optional[str] = None,
):
self.config = config
self.mgr = model_manager or ModelManager(config.model_conf_path)

View File

@ -228,19 +228,19 @@ the root is the InvokeAI ROOTDIR.
"""
from __future__ import annotations
import os
import hashlib
import os
import textwrap
import yaml
import types
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
from shutil import rmtree, move
from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
import torch
import yaml
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pydantic import BaseModel, Field
import invokeai.backend.util.logging as logger
@ -259,6 +259,7 @@ from .models import (
ModelNotFoundException,
InvalidModelException,
DuplicateModelException,
ModelBase,
)
# We are only starting to number the config file with release 3.
@ -361,7 +362,7 @@ class ModelManager(object):
if model_key.startswith("_"):
continue
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
@ -381,18 +382,24 @@ class ModelManager(object):
# causing otherwise unreferenced models to be removed from memory
self._read_models()
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
"""
Given a model name, returns True if it is a valid
identifier.
Given a model name, returns True if it is a valid identifier.
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param rescan: if True, scan_models_directory
"""
model_key = self.create_key(model_name, base_model, model_type)
return model_key in self.models
exists = model_key in self.models
# if model not found try to find it (maybe file just pasted)
if rescan and not exists:
self.scan_models_directory(base_model=base_model, model_type=model_type)
exists = self.model_exists(model_name, base_model, model_type, rescan=False)
return exists
@classmethod
def create_key(
@ -443,39 +450,32 @@ class ModelManager(object):
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param submode_typel: an ModelType enum indicating the portion of
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
model_class = MODEL_CLASSES[base_model][model_type]
model_key = self.create_key(model_name, base_model, model_type)
# if model not found try to find it (maybe file just pasted)
if model_key not in self.models:
self.scan_models_directory(base_model=base_model, model_type=model_type)
if model_key not in self.models:
raise ModelNotFoundException(f"Model not found - {model_key}")
if not self.model_exists(model_name, base_model, model_type, rescan=True):
raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self.models[model_key]
model_path = self.resolve_model_path(model_config.path)
model_config = self._get_model_config(base_model, model_name, model_type)
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
if is_submodel_override:
model_type = submodel_type
submodel_type = None
model_class = self._get_implementation(base_model, model_type)
if not model_path.exists():
if model_class.save_to_config:
self.models[model_key].error = ModelError.NotFound
raise Exception(f'Files for model "{model_key}" not found')
raise Exception(f'Files for model "{model_key}" not found at {model_path}')
else:
self.models.pop(model_key, None)
raise ModelNotFoundException(f"Model not found - {model_key}")
# vae/movq override
# TODO:
if submodel_type is not None and hasattr(model_config, submodel_type):
override_path = getattr(model_config, submodel_type)
if override_path:
model_path = self.resolve_path(override_path)
model_type = submodel_type
submodel_type = None
model_class = MODEL_CLASSES[base_model][model_type]
raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}')
# TODO: path
# TODO: is it accurate to use path as id
@ -513,12 +513,61 @@ class ModelManager(object):
_cache=self.cache,
)
def _get_model_path(
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
) -> (Path, bool):
"""Extract a model's filesystem path from its config.
:return: The fully qualified Path of the module (or submodule).
"""
model_path = model_config.path
is_submodel_override = False
# Does the config explicitly override the submodel?
if submodel_type is not None and hasattr(model_config, submodel_type):
submodel_path = getattr(model_config, submodel_type)
if submodel_path is not None:
model_path = getattr(model_config, submodel_type)
is_submodel_override = True
model_path = self.resolve_model_path(model_path)
return model_path, is_submodel_override
def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
"""Get a model's config object."""
model_key = self.create_key(model_name, base_model, model_type)
try:
model_config = self.models[model_key]
except KeyError:
raise ModelNotFoundException(f"Model not found - {model_key}")
return model_config
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
"""Get the concrete implementation class for a specific model type."""
model_class = MODEL_CLASSES[base_model][model_type]
return model_class
def _instantiate(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
) -> ModelBase:
"""Make a new instance of this model, without loading it."""
model_config = self._get_model_config(base_model, model_name, model_type)
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
# FIXME: do non-overriden submodels get the right class?
constructor = self._get_implementation(base_model, model_type)
instance = constructor(model_path, base_model, model_type)
return instance
def model_info(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> dict:
) -> Union[dict, None]:
"""
Given a model name returns the OmegaConf (dict-like) object describing it.
"""
@ -540,13 +589,16 @@ class ModelManager(object):
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> dict:
) -> Union[dict, None]:
"""
Returns a dict describing one installed model, using
the combined format of the list_models() method.
"""
models = self.list_models(base_model, model_type, model_name)
return models[0] if models else None
if len(models) >= 1:
return models[0]
else:
return None
def list_models(
self,
@ -560,7 +612,7 @@ class ModelManager(object):
model_keys = (
[self.create_key(model_name, base_model, model_type)]
if model_name
if model_name and base_model and model_type
else sorted(self.models, key=str.casefold)
)
models = []
@ -596,7 +648,7 @@ class ModelManager(object):
Print a table of models and their descriptions. This needs to be redone
"""
# TODO: redo
for model_type, model_dict in self.list_models().items():
for model_dict in self.list_models():
for model_name, model_info in model_dict.items():
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
print(line)
@ -658,7 +710,7 @@ class ModelManager(object):
if path := model_attributes.get("path"):
model_attributes["path"] = str(self.relative_model_path(Path(path)))
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
model_config = model_class.create_config(**model_attributes)
model_key = self.create_key(model_name, base_model, model_type)
@ -699,8 +751,8 @@ class ModelManager(object):
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str = None,
new_base: BaseModelType = None,
new_name: Optional[str] = None,
new_base: Optional[BaseModelType] = None,
):
"""
Rename or rebase a model.
@ -753,7 +805,7 @@ class ModelManager(object):
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main, ModelType.Vae],
model_type: Literal[ModelType.Main, ModelType.Vae],
dest_directory: Optional[Path] = None,
) -> AddModelResult:
"""
@ -767,6 +819,10 @@ class ModelManager(object):
This will raise a ValueError unless the model is a checkpoint.
"""
info = self.model_info(model_name, base_model, model_type)
if info is None:
raise FileNotFoundError(f"model not found: {model_name}")
if info["model_format"] != "checkpoint":
raise ValueError(f"not a checkpoint format model: {model_name}")
@ -836,7 +892,7 @@ class ModelManager(object):
return search_folder, found_models
def commit(self, conf_file: Path = None) -> None:
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
Write current configuration out to the indicated file.
"""
@ -845,7 +901,7 @@ class ModelManager(object):
for model_key, model_config in self.models.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
if model_class.save_to_config:
# TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
@ -903,7 +959,7 @@ class ModelManager(object):
model_path = self.resolve_model_path(model_config.path).absolute()
if not model_path.exists():
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
model_class = self._get_implementation(cur_base_model, cur_model_type)
if model_class.save_to_config:
model_config.error = ModelError.NotFound
self.models.pop(model_key, None)
@ -919,7 +975,7 @@ class ModelManager(object):
for cur_model_type in ModelType:
if model_type is not None and cur_model_type != model_type:
continue
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
model_class = self._get_implementation(cur_base_model, cur_model_type)
models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
if not models_dir.exists():
@ -935,7 +991,9 @@ class ModelManager(object):
raise DuplicateModelException(f"Model with key {model_key} added twice")
model_path = self.relative_model_path(model_path)
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
model_config: ModelConfigBase = model_class.probe_config(
str(model_path), model_base=cur_base_model
)
self.models[model_key] = model_config
new_models_found = True
except DuplicateModelException as e:
@ -983,7 +1041,7 @@ class ModelManager(object):
# LS: hacky
# Patch in the SD VAE from core so that it is available for use by the UI
try:
self.heuristic_import({self.resolve_model_path("core/convert/sd-vae-ft-mse")})
self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))})
except:
pass
@ -1011,7 +1069,7 @@ class ModelManager(object):
def heuristic_import(
self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> Dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.

View File

@ -33,7 +33,7 @@ class ModelMerger(object):
self,
model_paths: List[Path],
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
@ -73,7 +73,7 @@ class ModelMerger(object):
base_model: Union[BaseModelType, str],
merged_model_name: str,
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
merge_dest_directory: Optional[Path] = None,
**kwargs,
@ -122,7 +122,7 @@ class ModelMerger(object):
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
attributes = dict(
path=str(dump_path),
description=f"Merge of models {', '.join(model_names)}",

View File

@ -80,8 +80,10 @@ class StableDiffusionXLModel(DiffusersModel):
raise Exception("Unkown stable diffusion 2.* model format")
if ckpt_config_path is None:
# TO DO: implement picking
pass
# avoid circular import
from .stable_diffusion import _select_ckpt_config
ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant)
return cls.create_config(
path=path,

View File

@ -1,9 +1,14 @@
import os
import torch
import safetensors
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from typing import Optional
import safetensors
import torch
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
from invokeai.app.services.config import InvokeAIAppConfig
from .base import (
ModelBase,
ModelConfigBase,
@ -18,9 +23,6 @@ from .base import (
InvalidModelException,
ModelNotFoundException,
)
from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
class VaeModelFormat(str, Enum):
@ -80,7 +82,7 @@ class VaeModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException()
raise ModelNotFoundException(f"Does not exist as local file: {path}")
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):

View File

@ -1,18 +1,19 @@
from __future__ import annotations
import dataclasses
import inspect
import math
import secrets
from dataclasses import dataclass, field
import inspect
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import Field
import math
import einops
import PIL.Image
import numpy as np
import einops
import psutil
import torch
import torchvision.transforms as T
from accelerate.utils import set_seed
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@ -27,17 +28,18 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from pydantic import Field
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec
from invokeai.app.services.config import InvokeAIAppConfig
from ..util import CPU_DEVICE, normalize_device
from .diffusion import (
AttentionMapSaver,
InvokeAIDiffuserComponent,
PostprocessingSettings,
)
from ..util import normalize_device
@dataclass
@ -292,9 +294,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
precision: str = "float32",
control_model: ControlNetModel = None,
execution_device: Optional[torch.device] = None,
):
super().__init__(
vae,
@ -335,12 +335,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return
if self.device.type == "cpu" or self.device.type == "mps":
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
mem_free = psutil.virtual_memory().free
elif self.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
elif self.unet.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
else:
raise ValueError(f"unrecognized device {self.device}")
raise ValueError(f"unrecognized device {self.unet.device}")
# input tensor of [1, 4, h/8, w/8]
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
@ -363,10 +363,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
raise Exception("Should not be called")
@property
def device(self) -> torch.device:
return self.unet.device
def latents_from_embeddings(
self,
latents: torch.Tensor,

View File

@ -0,0 +1,795 @@
# Copyright (c) 2023 - The InvokeAI Team
# Primary Author: David Lovell (github @f412design, discord @techjedi)
# co-author, minor tweaks - Lincoln Stein
# pylint: disable=line-too-long
# pylint: disable=broad-exception-caught
"""Script to import images into the new database system for 3.0.0"""
import os
import datetime
import shutil
import locale
import sqlite3
import json
import glob
import re
import uuid
import yaml
import PIL
import PIL.ImageOps
import PIL.PngImagePlugin
from pathlib import Path
from prompt_toolkit import prompt
from prompt_toolkit.shortcuts import message_dialog
from prompt_toolkit.completion import PathCompleter
from prompt_toolkit.key_binding import KeyBindings
from invokeai.app.services.config import InvokeAIAppConfig
app_config = InvokeAIAppConfig.get_config()
bindings = KeyBindings()
@bindings.add("c-c")
def _(event):
raise KeyboardInterrupt
# release notes
# "Use All" with size dimensions not selectable in the UI will not load dimensions
class Config:
"""Configuration loader."""
def __init__(self):
pass
TIMESTAMP_STRING = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
INVOKE_DIRNAME = "invokeai"
YAML_FILENAME = "invokeai.yaml"
DATABASE_FILENAME = "invokeai.db"
database_path = None
database_backup_dir = None
outputs_path = None
thumbnail_path = None
def find_and_load(self):
"""find the yaml config file and load"""
root = app_config.root_path
if not self.confirm_and_load(os.path.abspath(root)):
print("\r\nSpecify custom database and outputs paths:")
self.confirm_and_load_from_user()
self.database_backup_dir = os.path.join(os.path.dirname(self.database_path), "backup")
self.thumbnail_path = os.path.join(self.outputs_path, "thumbnails")
def confirm_and_load(self, invoke_root):
"""Validates a yaml path exists, confirms the user wants to use it and loads config."""
yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
if os.path.exists(yaml_path):
db_dir, outdir = self.load_paths_from_yaml(yaml_path)
if os.path.isabs(db_dir):
database_path = os.path.join(db_dir, self.DATABASE_FILENAME)
else:
database_path = os.path.join(invoke_root, db_dir, self.DATABASE_FILENAME)
if os.path.isabs(outdir):
outputs_path = os.path.join(outdir, "images")
else:
outputs_path = os.path.join(invoke_root, outdir, "images")
db_exists = os.path.exists(database_path)
outdir_exists = os.path.exists(outputs_path)
text = f"Found {self.YAML_FILENAME} file at {yaml_path}:"
text += f"\n Database : {database_path}"
text += f"\n Outputs : {outputs_path}"
text += "\n\nUse these paths for import (yes) or choose different ones (no) [Yn]: "
if db_exists and outdir_exists:
if (prompt(text).strip() or "Y").upper().startswith("Y"):
self.database_path = database_path
self.outputs_path = outputs_path
return True
else:
return False
else:
print(" Invalid: One or more paths in this config did not exist and cannot be used.")
else:
message_dialog(
title="Path not found",
text=f"Auto-discovery of configuration failed! Could not find ({yaml_path}), Custom paths can be specified.",
).run()
return False
def confirm_and_load_from_user(self):
default = ""
while True:
database_path = os.path.expanduser(
prompt(
"Database: Specify absolute path to the database to import into: ",
completer=PathCompleter(
expanduser=True, file_filter=lambda x: Path(x).is_dir() or x.endswith((".db"))
),
default=default,
)
)
if database_path.endswith(".db") and os.path.isabs(database_path) and os.path.exists(database_path):
break
default = database_path + "/" if Path(database_path).is_dir() else database_path
default = ""
while True:
outputs_path = os.path.expanduser(
prompt(
"Outputs: Specify absolute path to outputs/images directory to import into: ",
completer=PathCompleter(expanduser=True, only_directories=True),
default=default,
)
)
if outputs_path.endswith("images") and os.path.isabs(outputs_path) and os.path.exists(outputs_path):
break
default = outputs_path + "/" if Path(outputs_path).is_dir() else outputs_path
self.database_path = database_path
self.outputs_path = outputs_path
return
def load_paths_from_yaml(self, yaml_path):
"""Load an Invoke AI yaml file and get the database and outputs paths."""
try:
with open(yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
yamlinfo = yaml.safe_load(file)
db_dir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("db_dir", None)
outdir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("outdir", None)
return db_dir, outdir
except Exception:
print(f"Failed to load paths from yaml file! {yaml_path}!")
return None, None
class ImportStats:
"""DTO for tracking work progress."""
def __init__(self):
pass
time_start = datetime.datetime.utcnow()
count_source_files = 0
count_skipped_file_exists = 0
count_skipped_db_exists = 0
count_imported = 0
count_imported_by_version = {}
count_file_errors = 0
@staticmethod
def get_elapsed_time_string():
"""Get a friendly time string for the time elapsed since processing start."""
time_now = datetime.datetime.utcnow()
total_seconds = (time_now - ImportStats.time_start).total_seconds()
hours = int((total_seconds) / 3600)
minutes = int(((total_seconds) % 3600) / 60)
seconds = total_seconds % 60
out_str = f"{hours} hour(s) -" if hours > 0 else ""
out_str += f"{minutes} minute(s) -" if minutes > 0 else ""
out_str += f"{seconds:.2f} second(s)"
return out_str
class InvokeAIMetadata:
"""DTO for core Invoke AI generation properties parsed from metadata."""
def __init__(self):
pass
def __str__(self):
formatted_str = f"{self.generation_mode}~{self.steps}~{self.cfg_scale}~{self.model_name}~{self.scheduler}~{self.seed}~{self.width}~{self.height}~{self.rand_device}~{self.strength}~{self.init_image}"
formatted_str += f"\r\npositive_prompt: {self.positive_prompt}"
formatted_str += f"\r\nnegative_prompt: {self.negative_prompt}"
return formatted_str
generation_mode = None
steps = None
cfg_scale = None
model_name = None
scheduler = None
seed = None
width = None
height = None
rand_device = None
strength = None
init_image = None
positive_prompt = None
negative_prompt = None
imported_app_version = None
def to_json(self):
"""Convert the active instance to json format."""
prop_dict = {}
prop_dict["generation_mode"] = self.generation_mode
# dont render prompt nodes if neither are set to avoid the ui thinking it can set them
# if at least one exists, render them both, but use empty string instead of None if one of them is empty
# this allows the field that is empty to actually be cleared byt he UI instead of leaving the previous value
if self.positive_prompt or self.negative_prompt:
prop_dict["positive_prompt"] = "" if self.positive_prompt is None else self.positive_prompt
prop_dict["negative_prompt"] = "" if self.negative_prompt is None else self.negative_prompt
prop_dict["width"] = self.width
prop_dict["height"] = self.height
# only render seed if it has a value to avoid ui thinking it can set this and then error
if self.seed:
prop_dict["seed"] = self.seed
prop_dict["rand_device"] = self.rand_device
prop_dict["cfg_scale"] = self.cfg_scale
prop_dict["steps"] = self.steps
prop_dict["scheduler"] = self.scheduler
prop_dict["clip_skip"] = 0
prop_dict["model"] = {}
prop_dict["model"]["model_name"] = self.model_name
prop_dict["model"]["base_model"] = None
prop_dict["controlnets"] = []
prop_dict["loras"] = []
prop_dict["vae"] = None
prop_dict["strength"] = self.strength
prop_dict["init_image"] = self.init_image
prop_dict["positive_style_prompt"] = None
prop_dict["negative_style_prompt"] = None
prop_dict["refiner_model"] = None
prop_dict["refiner_cfg_scale"] = None
prop_dict["refiner_steps"] = None
prop_dict["refiner_scheduler"] = None
prop_dict["refiner_aesthetic_store"] = None
prop_dict["refiner_start"] = None
prop_dict["imported_app_version"] = self.imported_app_version
return json.dumps(prop_dict)
class InvokeAIMetadataParser:
"""Parses strings with json data to find Invoke AI core metadata properties."""
def __init__(self):
pass
def parse_meta_tag_dream(self, dream_string):
"""Take as input an png metadata json node for the 'dream' field variant from prior to 1.15"""
props = InvokeAIMetadata()
props.imported_app_version = "pre1.15"
seed_match = re.search("-S\\s*(\\d+)", dream_string)
if seed_match is not None:
try:
props.seed = int(seed_match[1])
except ValueError:
props.seed = None
raw_prompt = re.sub("(-S\\s*\\d+)", "", dream_string)
else:
raw_prompt = dream_string
pos_prompt, neg_prompt = self.split_prompt(raw_prompt)
props.positive_prompt = pos_prompt
props.negative_prompt = neg_prompt
return props
def parse_meta_tag_sd_metadata(self, tag_value):
"""Take as input an png metadata json node for the 'sd-metadata' field variant from 1.15 through 2.3.5 post 2"""
props = InvokeAIMetadata()
props.imported_app_version = tag_value.get("app_version")
props.model_name = tag_value.get("model_weights")
img_node = tag_value.get("image")
if img_node is not None:
props.generation_mode = img_node.get("type")
props.width = img_node.get("width")
props.height = img_node.get("height")
props.seed = img_node.get("seed")
props.rand_device = "cuda" # hardcoded since all generations pre 3.0 used cuda random noise instead of cpu
props.cfg_scale = img_node.get("cfg_scale")
props.steps = img_node.get("steps")
props.scheduler = self.map_scheduler(img_node.get("sampler"))
props.strength = img_node.get("strength")
if props.strength is None:
props.strength = img_node.get("strength_steps") # try second name for this property
props.init_image = img_node.get("init_image_path")
if props.init_image is None: # try second name for this property
props.init_image = img_node.get("init_img")
# remove the path info from init_image so if we move the init image, it will be correctly relative in the new location
if props.init_image is not None:
props.init_image = os.path.basename(props.init_image)
raw_prompt = img_node.get("prompt")
if isinstance(raw_prompt, list):
raw_prompt = raw_prompt[0].get("prompt")
props.positive_prompt, props.negative_prompt = self.split_prompt(raw_prompt)
return props
def parse_meta_tag_invokeai(self, tag_value):
"""Take as input an png metadata json node for the 'invokeai' field variant from 3.0.0 beta 1 through 5"""
props = InvokeAIMetadata()
props.imported_app_version = "3.0.0 or later"
props.generation_mode = tag_value.get("type")
if props.generation_mode is not None:
props.generation_mode = props.generation_mode.replace("t2l", "txt2img").replace("l2l", "img2img")
props.width = tag_value.get("width")
props.height = tag_value.get("height")
props.seed = tag_value.get("seed")
props.cfg_scale = tag_value.get("cfg_scale")
props.steps = tag_value.get("steps")
props.scheduler = tag_value.get("scheduler")
props.strength = tag_value.get("strength")
props.positive_prompt = tag_value.get("positive_conditioning")
props.negative_prompt = tag_value.get("negative_conditioning")
return props
def map_scheduler(self, old_scheduler):
"""Convert the legacy sampler names to matching 3.0 schedulers"""
if old_scheduler is None:
return None
match (old_scheduler):
case "ddim":
return "ddim"
case "plms":
return "pnmd"
case "k_lms":
return "lms"
case "k_dpm_2":
return "kdpm_2"
case "k_dpm_2_a":
return "kdpm_2_a"
case "dpmpp_2":
return "dpmpp_2s"
case "k_dpmpp_2":
return "dpmpp_2m"
case "k_dpmpp_2_a":
return None # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
case "k_euler":
return "euler"
case "k_euler_a":
return "euler_a"
case "k_heun":
return "heun"
return None
def split_prompt(self, raw_prompt: str):
"""Split the unified prompt strings by extracting all negative prompt blocks out into the negative prompt."""
if raw_prompt is None:
return "", ""
raw_prompt_search = raw_prompt.replace("\r", "").replace("\n", "")
matches = re.findall(r"\[(.+?)\]", raw_prompt_search)
if len(matches) > 0:
negative_prompt = ""
if len(matches) == 1:
negative_prompt = matches[0].strip().strip(",")
else:
for match in matches:
negative_prompt += f"({match.strip().strip(',')})"
positive_prompt = re.sub(r"(\[.+?\])", "", raw_prompt_search).strip()
else:
positive_prompt = raw_prompt_search.strip()
negative_prompt = ""
return positive_prompt, negative_prompt
class DatabaseMapper:
"""Class to abstract database functionality."""
def __init__(self, database_path, database_backup_dir):
self.database_path = database_path
self.database_backup_dir = database_backup_dir
self.connection = None
self.cursor = None
def connect(self):
"""Open connection to the database."""
self.connection = sqlite3.connect(self.database_path)
self.cursor = self.connection.cursor()
def get_board_names(self):
"""Get a list of the current board names from the database."""
sql_get_board_name = "SELECT board_name FROM boards"
self.cursor.execute(sql_get_board_name)
rows = self.cursor.fetchall()
return [row[0] for row in rows]
def does_image_exist(self, image_name):
"""Check database if a image name already exists and return a boolean."""
sql_get_image_by_name = f"SELECT image_name FROM images WHERE image_name='{image_name}'"
self.cursor.execute(sql_get_image_by_name)
rows = self.cursor.fetchall()
return True if len(rows) > 0 else False
def add_new_image_to_database(self, filename, width, height, metadata, modified_date_string):
"""Add an image to the database."""
sql_add_image = f"""INSERT INTO images (image_name, image_origin, image_category, width, height, session_id, node_id, metadata, is_intermediate, created_at, updated_at)
VALUES ('{filename}', 'internal', 'general', {width}, {height}, null, null, '{metadata}', 0, '{modified_date_string}', '{modified_date_string}')"""
self.cursor.execute(sql_add_image)
self.connection.commit()
def get_board_id_with_create(self, board_name):
"""Get the board id for supplied name, and create the board if one does not exist."""
sql_find_board = f"SELECT board_id FROM boards WHERE board_name='{board_name}' COLLATE NOCASE"
self.cursor.execute(sql_find_board)
rows = self.cursor.fetchall()
if len(rows) > 0:
return rows[0][0]
else:
board_date_string = datetime.datetime.utcnow().date().isoformat()
new_board_id = str(uuid.uuid4())
sql_insert_board = f"INSERT INTO boards (board_id, board_name, created_at, updated_at) VALUES ('{new_board_id}', '{board_name}', '{board_date_string}', '{board_date_string}')"
self.cursor.execute(sql_insert_board)
self.connection.commit()
return new_board_id
def add_image_to_board(self, filename, board_id):
"""Add an image mapping to a board."""
add_datetime_str = datetime.datetime.utcnow().isoformat()
sql_add_image_to_board = f"""INSERT INTO board_images (board_id, image_name, created_at, updated_at)
VALUES ('{board_id}', '{filename}', '{add_datetime_str}', '{add_datetime_str}')"""
self.cursor.execute(sql_add_image_to_board)
self.connection.commit()
def disconnect(self):
"""Disconnect from the db, cleaning up connections and cursors."""
if self.cursor is not None:
self.cursor.close()
if self.connection is not None:
self.connection.close()
def backup(self, timestamp_string):
"""Take a backup of the database."""
if not os.path.exists(self.database_backup_dir):
print(f"Database backup directory {self.database_backup_dir} does not exist -> creating...", end="")
os.makedirs(self.database_backup_dir)
print("Done!")
database_backup_path = os.path.join(self.database_backup_dir, f"backup-{timestamp_string}-invokeai.db")
print(f"Making DB Backup at {database_backup_path}...", end="")
shutil.copy2(self.database_path, database_backup_path)
print("Done!")
class MediaImportProcessor:
"""Containing class for script functionality."""
def __init__(self):
pass
board_name_id_map = {}
def get_import_file_list(self):
"""Ask the user for the import folder and scan for the list of files to return."""
while True:
default = ""
while True:
import_dir = os.path.expanduser(
prompt(
"Inputs: Specify absolute path containing InvokeAI .png images to import: ",
completer=PathCompleter(expanduser=True, only_directories=True),
default=default,
)
)
if len(import_dir) > 0 and Path(import_dir).is_dir():
break
default = import_dir
recurse_directories = (
(prompt("Include files from subfolders recursively [yN]? ").strip() or "N").upper().startswith("N")
)
if recurse_directories:
is_recurse = False
matching_file_list = glob.glob(import_dir + "/*.png", recursive=False)
else:
is_recurse = True
matching_file_list = glob.glob(import_dir + "/**/*.png", recursive=True)
if len(matching_file_list) > 0:
return import_dir, is_recurse, matching_file_list
else:
print(f"The specific path {import_dir} exists, but does not contain .png files!")
def get_file_details(self, filepath):
"""Retrieve the embedded metedata fields and dimensions from an image file."""
with PIL.Image.open(filepath) as img:
img.load()
png_width, png_height = img.size
img_info = img.info
return img_info, png_width, png_height
def select_board_option(self, board_names, timestamp_string):
"""Allow the user to choose how a board is selected for imported files."""
while True:
print("\r\nOptions for board selection for imported images:")
print(f"1) Select an existing board name. (found {len(board_names)})")
print("2) Specify a board name to create/add to.")
print("3) Create/add to board named 'IMPORT'.")
print(
f"4) Create/add to board named 'IMPORT' with the current datetime string appended (.e.g IMPORT_{timestamp_string})."
)
print(
"5) Create/add to board named 'IMPORT' with a the original file app_version appended (.e.g IMPORT_2.2.5)."
)
input_option = input("Specify desired board option: ")
match (input_option):
case "1":
if len(board_names) < 1:
print("\r\nThere are no existing board names to choose from. Select another option!")
continue
board_name = self.select_item_from_list(
board_names, "board name", True, "Cancel, go back and choose a different board option."
)
if board_name is not None:
return board_name
case "2":
while True:
board_name = input("Specify new/existing board name: ")
if board_name:
return board_name
case "3":
return "IMPORT"
case "4":
return f"IMPORT_{timestamp_string}"
case "5":
return "IMPORT_APPVERSION"
def select_item_from_list(self, items, entity_name, allow_cancel, cancel_string):
"""A general function to render a list of items to select in the console, prompt the user for a selection and ensure a valid entry is selected."""
print(f"Select a {entity_name.lower()} from the following list:")
index = 1
for item in items:
print(f"{index}) {item}")
index += 1
if allow_cancel:
print(f"{index}) {cancel_string}")
while True:
try:
option_number = int(input("Specify number of selection: "))
except ValueError:
continue
if allow_cancel and option_number == index:
return None
if option_number >= 1 and option_number <= len(items):
return items[option_number - 1]
def import_image(self, filepath: str, board_name_option: str, db_mapper: DatabaseMapper, config: Config):
"""Import a single file by its path"""
parser = InvokeAIMetadataParser()
file_name = os.path.basename(filepath)
file_destination_path = os.path.join(config.outputs_path, file_name)
print("===============================================================================")
print(f"Importing {filepath}")
# check destination to see if the file was previously imported
if os.path.exists(file_destination_path):
print("File already exists in the destination, skipping!")
ImportStats.count_skipped_file_exists += 1
return
# check if file name is already referenced in the database
if db_mapper.does_image_exist(file_name):
print("A reference to a file with this name already exists in the database, skipping!")
ImportStats.count_skipped_db_exists += 1
return
# load image info and dimensions
img_info, png_width, png_height = self.get_file_details(filepath)
# parse metadata
destination_needs_meta_update = True
log_version_note = "(Unknown)"
if "invokeai_metadata" in img_info:
# for the latest, we will just re-emit the same json, no need to parse/modify
converted_field = None
latest_json_string = img_info.get("invokeai_metadata")
log_version_note = "3.0.0+"
destination_needs_meta_update = False
else:
if "sd-metadata" in img_info:
converted_field = parser.parse_meta_tag_sd_metadata(json.loads(img_info.get("sd-metadata")))
elif "invokeai" in img_info:
converted_field = parser.parse_meta_tag_invokeai(json.loads(img_info.get("invokeai")))
elif "dream" in img_info:
converted_field = parser.parse_meta_tag_dream(img_info.get("dream"))
elif "Dream" in img_info:
converted_field = parser.parse_meta_tag_dream(img_info.get("Dream"))
else:
converted_field = InvokeAIMetadata()
destination_needs_meta_update = False
print("File does not have metadata from known Invoke AI versions, add only, no update!")
# use the loaded img dimensions if the metadata didnt have them
if converted_field.width is None:
converted_field.width = png_width
if converted_field.height is None:
converted_field.height = png_height
log_version_note = converted_field.imported_app_version if converted_field else "NoVersion"
log_version_note = log_version_note or "NoVersion"
latest_json_string = converted_field.to_json()
print(f"From Invoke AI Version {log_version_note} with dimensions {png_width} x {png_height}.")
# if metadata needs update, then update metdata and copy in one shot
if destination_needs_meta_update:
print("Updating metadata while copying...", end="")
self.update_file_metadata_while_copying(
filepath, file_destination_path, "invokeai_metadata", latest_json_string
)
print("Done!")
else:
print("No metadata update necessary, copying only...", end="")
shutil.copy2(filepath, file_destination_path)
print("Done!")
# create thumbnail
print("Creating thumbnail...", end="")
thumbnail_path = os.path.join(config.thumbnail_path, os.path.splitext(file_name)[0]) + ".webp"
thumbnail_size = 256, 256
with PIL.Image.open(filepath) as source_image:
source_image.thumbnail(thumbnail_size)
source_image.save(thumbnail_path, "webp")
print("Done!")
# finalize the dynamic board name if there is an APPVERSION token in it.
if converted_field is not None:
board_name = board_name_option.replace("APPVERSION", converted_field.imported_app_version or "NoVersion")
else:
board_name = board_name_option.replace("APPVERSION", "Latest")
# maintain a map of alrady created/looked up ids to avoid DB queries
print("Finding/Creating board...", end="")
if board_name in self.board_name_id_map:
board_id = self.board_name_id_map[board_name]
else:
board_id = db_mapper.get_board_id_with_create(board_name)
self.board_name_id_map[board_name] = board_id
print("Done!")
# add image to db
print("Adding image to database......", end="")
modified_time = datetime.datetime.utcfromtimestamp(os.path.getmtime(filepath))
db_mapper.add_new_image_to_database(file_name, png_width, png_height, latest_json_string, modified_time)
print("Done!")
# add image to board
print("Adding image to board......", end="")
db_mapper.add_image_to_board(file_name, board_id)
print("Done!")
ImportStats.count_imported += 1
if log_version_note in ImportStats.count_imported_by_version:
ImportStats.count_imported_by_version[log_version_note] += 1
else:
ImportStats.count_imported_by_version[log_version_note] = 1
def update_file_metadata_while_copying(self, filepath, file_destination_path, tag_name, tag_value):
"""Perform a metadata update with save to a new destination which accomplishes a copy while updating metadata."""
with PIL.Image.open(filepath) as target_image:
existing_img_info = target_image.info
metadata = PIL.PngImagePlugin.PngInfo()
# re-add any existing invoke ai tags unless they are the one we are trying to add
for key in existing_img_info:
if key != tag_name and key in ("dream", "Dream", "sd-metadata", "invokeai", "invokeai_metadata"):
metadata.add_text(key, existing_img_info[key])
metadata.add_text(tag_name, tag_value)
target_image.save(file_destination_path, pnginfo=metadata)
def process(self):
"""Begin main processing."""
print("===============================================================================")
print("This script will import images generated by earlier versions of")
print("InvokeAI into the currently installed root directory:")
print(f" {app_config.root_path}")
print("If this is not what you want to do, type ctrl-C now to cancel.")
# load config
print("===============================================================================")
print("= Configuration & Settings")
config = Config()
config.find_and_load()
db_mapper = DatabaseMapper(config.database_path, config.database_backup_dir)
db_mapper.connect()
import_dir, is_recurse, import_file_list = self.get_import_file_list()
ImportStats.count_source_files = len(import_file_list)
board_names = db_mapper.get_board_names()
board_name_option = self.select_board_option(board_names, config.TIMESTAMP_STRING)
print("\r\n===============================================================================")
print("= Import Settings Confirmation")
print()
print(f"Database File Path : {config.database_path}")
print(f"Outputs/Images Directory : {config.outputs_path}")
print(f"Import Image Source Directory : {import_dir}")
print(f" Recurse Source SubDirectories : {'Yes' if is_recurse else 'No'}")
print(f"Count of .png file(s) found : {len(import_file_list)}")
print(f"Board name option specified : {board_name_option}")
print(f"Database backup will be taken at : {config.database_backup_dir}")
print("\r\nNotes about the import process:")
print("- Source image files will not be modified, only copied to the outputs directory.")
print("- If the same file name already exists in the destination, the file will be skipped.")
print("- If the same file name already has a record in the database, the file will be skipped.")
print("- Invoke AI metadata tags will be updated/written into the imported copy only.")
print(
"- On the imported copy, only Invoke AI known tags (latest and legacy) will be retained (dream, sd-metadata, invokeai, invokeai_metadata)"
)
print(
"- A property 'imported_app_version' will be added to metadata that can be viewed in the UI's metadata viewer."
)
print(
"- The new 3.x InvokeAI outputs folder structure is flat so recursively found source imges will all be placed into the single outputs/images folder."
)
while True:
should_continue = prompt("\nDo you wish to continue with the import [Yn] ? ").lower() or "y"
if should_continue == "n":
print("\r\nCancelling Import")
return
elif should_continue == "y":
print()
break
db_mapper.backup(config.TIMESTAMP_STRING)
print()
ImportStats.time_start = datetime.datetime.utcnow()
for filepath in import_file_list:
try:
self.import_image(filepath, board_name_option, db_mapper, config)
except sqlite3.Error as sql_ex:
print(f"A database related exception was found processing {filepath}, will continue to next file. ")
print("Exception detail:")
print(sql_ex)
ImportStats.count_file_errors += 1
except Exception as ex:
print(f"Exception processing {filepath}, will continue to next file. ")
print("Exception detail:")
print(ex)
ImportStats.count_file_errors += 1
print("\r\n===============================================================================")
print(f"= Import Complete - Elpased Time: {ImportStats.get_elapsed_time_string()}")
print()
print(f"Source File(s) : {ImportStats.count_source_files}")
print(f"Total Imported : {ImportStats.count_imported}")
print(f"Skipped b/c file already exists on disk : {ImportStats.count_skipped_file_exists}")
print(f"Skipped b/c file already exists in db : {ImportStats.count_skipped_db_exists}")
print(f"Errors during import : {ImportStats.count_file_errors}")
if ImportStats.count_imported > 0:
print("\r\nBreakdown of imported files by version:")
for key, version in ImportStats.count_imported_by_version.items():
print(f" {key:20} : {version}")
def main():
try:
processor = MediaImportProcessor()
processor.process()
except KeyboardInterrupt:
print("\r\n\r\nUser cancelled execution.")
if __name__ == "__main__":
main()

View File

@ -28,7 +28,6 @@ from npyscreen import widget
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.install.model_install_backend import (
ModelInstallList,
InstallSelections,
ModelInstall,
SchedulerPredictionType,
@ -41,12 +40,12 @@ from invokeai.frontend.install.widgets import (
SingleSelectColumns,
TextBox,
BufferBox,
FileBox,
set_min_terminal_size,
select_stable_diffusion_config_file,
CyclingForm,
MIN_COLS,
MIN_LINES,
WindowTooSmallException,
)
from invokeai.app.services.config import InvokeAIAppConfig
@ -156,7 +155,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
BufferBox,
name="Log Messages",
editable=False,
max_height=15,
max_height=6,
)
self.nextrely += 1
@ -693,7 +692,11 @@ def select_and_download_models(opt: Namespace):
# needed to support the probe() method running under a subprocess
torch.multiprocessing.set_start_method("spawn")
set_min_terminal_size(MIN_COLS, MIN_LINES)
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
raise WindowTooSmallException(
"Could not increase terminal size. Try running again with a larger window or smaller font size."
)
installApp = AddModelApplication(opt)
try:
installApp.run()
@ -787,6 +790,8 @@ def main():
curses.echo()
curses.endwin()
logger.info("Goodbye! Come back soon.")
except WindowTooSmallException as e:
logger.error(str(e))
except widget.NotEnoughSpaceForWidget as e:
if str(e).startswith("Height of 1 allocated"):
logger.error("Insufficient vertical space for the interface. Please make your window taller and try again")

View File

@ -21,31 +21,40 @@ MIN_COLS = 130
MIN_LINES = 38
class WindowTooSmallException(Exception):
pass
# -------------------------------------
def set_terminal_size(columns: int, lines: int):
ts = get_terminal_size()
width = max(columns, ts.columns)
height = max(lines, ts.lines)
def set_terminal_size(columns: int, lines: int) -> bool:
OS = platform.uname().system
if OS == "Windows":
pass
# not working reliably - ask user to adjust the window
# _set_terminal_size_powershell(width,height)
elif OS in ["Darwin", "Linux"]:
_set_terminal_size_unix(width, height)
screen_ok = False
while not screen_ok:
ts = get_terminal_size()
width = max(columns, ts.columns)
height = max(lines, ts.lines)
# check whether it worked....
ts = get_terminal_size()
pause = False
if ts.columns < columns:
print("\033[1mThis window is too narrow for the user interface.\033[0m")
pause = True
if ts.lines < lines:
print("\033[1mThis window is too short for the user interface.\033[0m")
pause = True
if pause:
input("Maximize the window then press any key to continue..")
if OS == "Windows":
pass
# not working reliably - ask user to adjust the window
# _set_terminal_size_powershell(width,height)
elif OS in ["Darwin", "Linux"]:
_set_terminal_size_unix(width, height)
# check whether it worked....
ts = get_terminal_size()
if ts.columns < columns or ts.lines < lines:
print(
f"\033[1mThis window is too small for the interface. InvokeAI requires {columns}x{lines} (w x h) characters, but window is {ts.columns}x{ts.lines}\033[0m"
)
resp = input(
"Maximize the window and/or decrease the font size then press any key to continue. Type [Q] to give up.."
)
if resp.upper().startswith("Q"):
break
else:
screen_ok = True
return screen_ok
def _set_terminal_size_powershell(width: int, height: int):
@ -80,14 +89,14 @@ def _set_terminal_size_unix(width: int, height: int):
sys.stdout.flush()
def set_min_terminal_size(min_cols: int, min_lines: int):
def set_min_terminal_size(min_cols: int, min_lines: int) -> bool:
# make sure there's enough room for the ui
term_cols, term_lines = get_terminal_size()
if term_cols >= min_cols and term_lines >= min_lines:
return
return True
cols = max(term_cols, min_cols)
lines = max(term_lines, min_lines)
set_terminal_size(cols, lines)
return set_terminal_size(cols, lines)
class IntSlider(npyscreen.Slider):
@ -164,7 +173,7 @@ class FloatSlider(npyscreen.Slider):
class FloatTitleSlider(npyscreen.TitleText):
_entry_type = FloatSlider
_entry_type = npyscreen.Slider
class SelectColumnBase:

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,4 @@
import{B as m,g7 as Je,A as y,a5 as Ka,g8 as Xa,af as va,aj as d,g9 as b,ga as t,gb as Ya,gc as h,gd as ua,ge as Ja,gf as Qa,aL as Za,gg as et,ad as rt,gh as at}from"./index-de589048.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./menu-11348abc.js";var za=String.raw,Ca=za`
import{B as m,g7 as Je,A as y,a5 as Ka,g8 as Xa,af as va,aj as d,g9 as b,ga as t,gb as Ya,gc as h,gd as ua,ge as Ja,gf as Qa,aL as Za,gg as et,ad as rt,gh as at}from"./index-dd054634.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./menu-b42141e3.js";var za=String.raw,Ca=za`
:root,
:host {
--chakra-vh: 100vh;

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-de589048.js"></script>
<script type="module" crossorigin src="./assets/index-dd054634.js"></script>
</head>
<body dir="ltr">

View File

@ -96,7 +96,8 @@ export type AppFeature =
| 'consoleLogging'
| 'dynamicPrompting'
| 'batches'
| 'syncModels';
| 'syncModels'
| 'multiselect';
/**
* A disable-able Stable Diffusion feature

View File

@ -9,6 +9,7 @@ import { useListImagesQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { selectionChanged } from '../store/gallerySlice';
import { imagesSelectors } from 'services/api/util';
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
const selector = createSelector(
[stateSelector, selectListImagesBaseQueryArgs],
@ -33,11 +34,18 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
}),
});
const isMultiSelectEnabled = useFeatureStatus('multiselect').isFeatureEnabled;
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (!imageDTO) {
return;
}
if (!isMultiSelectEnabled) {
dispatch(selectionChanged([imageDTO]));
return;
}
if (e.shiftKey) {
const rangeEndImageName = imageDTO.image_name;
const lastSelectedImage = selection[selection.length - 1]?.image_name;
@ -71,7 +79,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
dispatch(selectionChanged([imageDTO]));
}
},
[dispatch, imageDTO, imageDTOs, selection]
[dispatch, imageDTO, imageDTOs, selection, isMultiSelectEnabled]
);
const isSelected = useMemo(

View File

@ -31,7 +31,7 @@ const ParamLoraCollapse = () => {
}
return (
<IAICollapse label={'LoRA'} activeLabel={activeLabel}>
<IAICollapse label="LoRA" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 2 }}>
<ParamLoRASelect />
<ParamLoraList />

View File

@ -1,3 +1,4 @@
import { Divider } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
@ -8,20 +9,21 @@ import ParamLora from './ParamLora';
const selector = createSelector(
stateSelector,
({ lora }) => {
const { loras } = lora;
return { loras };
return { lorasArray: map(lora.loras) };
},
defaultSelectorOptions
);
const ParamLoraList = () => {
const { loras } = useAppSelector(selector);
const { lorasArray } = useAppSelector(selector);
return (
<>
{map(loras, (lora) => (
<ParamLora key={lora.model_name} lora={lora} />
{lorasArray.map((lora, i) => (
<>
{i > 0 && <Divider key={`${lora.model_name}-divider`} pt={1} />}
<ParamLora key={lora.model_name} lora={lora} />
</>
))}
</>
);

View File

@ -9,7 +9,6 @@ import {
CLIP_SKIP,
LORA_LOADER,
MAIN_MODEL_LOADER,
ONNX_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
@ -36,15 +35,11 @@ export const addLoRAsToGraph = (
| undefined;
if (loraCount > 0) {
// Remove MAIN_MODEL_LOADER unet connection to feed it to LoRAs
// Remove modelLoaderNodeId unet connection to feed it to LoRAs
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === MAIN_MODEL_LOADER &&
['unet'].includes(e.source.field)
) &&
!(
e.source.node_id === ONNX_MODEL_LOADER &&
e.source.node_id === modelLoaderNodeId &&
['unet'].includes(e.source.field)
)
);

View File

@ -0,0 +1,212 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, size } from 'lodash-es';
import {
MetadataAccumulatorInvocation,
SDXLLoraLoaderInvocation,
} from 'services/api/types';
import {
LORA_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
SDXL_MODEL_LOADER,
} from './constants';
export const addSDXLLoRAsToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId: string = SDXL_MODEL_LOADER
): void => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
* or to the inference/conditioning nodes.
*
* So we need to inject a LoRA chain into the graph.
*/
const { loras } = state.lora;
const loraCount = size(loras);
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (loraCount > 0) {
// Remove modelLoaderNodeId unet/clip/clip2 connections to feed it to LoRAs
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === modelLoaderNodeId &&
['unet'].includes(e.source.field)
) &&
!(
e.source.node_id === modelLoaderNodeId &&
['clip'].includes(e.source.field)
) &&
!(
e.source.node_id === modelLoaderNodeId &&
['clip2'].includes(e.source.field)
)
);
}
// we need to remember the last lora so we can chain from it
let lastLoraNodeId = '';
let currentLoraIndex = 0;
forEach(loras, (lora) => {
const { model_name, base_model, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
const loraLoaderNode: SDXLLoraLoaderInvocation = {
type: 'sdxl_lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
lora: { model_name, base_model },
weight,
};
// add the lora to the metadata accumulator
if (metadataAccumulator) {
metadataAccumulator.loras.push({
lora: { model_name, base_model },
weight,
});
}
// add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode;
if (currentLoraIndex === 0) {
// first lora = start the lora chain, attach directly to model loader
graph.edges.push({
source: {
node_id: modelLoaderNodeId,
field: 'unet',
},
destination: {
node_id: currentLoraNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: modelLoaderNodeId,
field: 'clip',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip',
},
});
graph.edges.push({
source: {
node_id: modelLoaderNodeId,
field: 'clip2',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip2',
},
});
} else {
// we are in the middle of the lora chain, instead connect to the previous lora
graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'unet',
},
destination: {
node_id: currentLoraNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'clip',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip',
},
});
graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'clip2',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip2',
},
});
}
if (currentLoraIndex === loraCount - 1) {
// final lora, end the lora chain - we need to connect up to inference and conditioning nodes
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'unet',
},
destination: {
node_id: baseNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
});
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
});
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip2',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip2',
},
});
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip2',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip2',
},
});
}
// increment the lora for the next one in the chain
lastLoraNodeId = currentLoraNodeId;
currentLoraIndex += 1;
});
};

View File

@ -22,6 +22,7 @@ import {
SDXL_LATENTS_TO_LATENTS,
SDXL_MODEL_LOADER,
} from './constants';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
/**
* Builds the Image to Image tab graph.
@ -364,6 +365,8 @@ export const buildLinearSDXLImageToImageGraph = (
},
});
addSDXLLoRAsToGraph(state, graph, SDXL_LATENTS_TO_LATENTS, SDXL_MODEL_LOADER);
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_LATENTS_TO_LATENTS);

View File

@ -4,6 +4,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -246,6 +247,8 @@ export const buildLinearSDXLTextToImageGraph = (
},
});
addSDXLLoRAsToGraph(state, graph, SDXL_TEXT_TO_LATENTS, SDXL_MODEL_LOADER);
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_TEXT_TO_LATENTS);

View File

@ -4,6 +4,7 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces
import ParamSDXLPromptArea from './ParamSDXLPromptArea';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters';
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
const SDXLImageToImageTabParameters = () => {
return (
@ -12,6 +13,7 @@ const SDXLImageToImageTabParameters = () => {
<ProcessButtons />
<SDXLImageToImageTabCoreParameters />
<ParamSDXLRefinerCollapse />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
</>

View File

@ -4,6 +4,7 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
import ParamSDXLPromptArea from './ParamSDXLPromptArea';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
const SDXLTextToImageTabParameters = () => {
return (
@ -12,6 +13,7 @@ const SDXLTextToImageTabParameters = () => {
<ProcessButtons />
<TextToImageTabCoreParameters />
<ParamSDXLRefinerCollapse />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
</>

View File

@ -4,6 +4,7 @@ import {
ASSETS_CATEGORIES,
BoardId,
IMAGE_CATEGORIES,
IMAGE_LIMIT,
} from 'features/gallery/store/types';
import { keyBy } from 'lodash';
import { ApiFullTagDescription, LIST_TAG, api } from '..';
@ -167,7 +168,14 @@ export const imagesApi = api.injectEndpoints({
},
};
},
invalidatesTags: (result, error, imageDTOs) => [],
invalidatesTags: (result, error, { imageDTOs }) => {
// for now, assume bulk delete is all on one board
const boardId = imageDTOs[0]?.board_id;
return [
{ type: 'BoardImagesTotal', id: boardId ?? 'none' },
{ type: 'BoardAssetsTotal', id: boardId ?? 'none' },
];
},
async onQueryStarted({ imageDTOs }, { dispatch, queryFulfilled }) {
/**
* Cache changes for `deleteImages`:
@ -889,18 +897,25 @@ export const imagesApi = api.injectEndpoints({
board_id,
},
}),
invalidatesTags: (result, error, { board_id }) => [
// update the destination board
{ type: 'Board', id: board_id ?? 'none' },
// update old board totals
{ type: 'BoardImagesTotal', id: board_id ?? 'none' },
{ type: 'BoardAssetsTotal', id: board_id ?? 'none' },
// update the no_board totals
{ type: 'BoardImagesTotal', id: 'none' },
{ type: 'BoardAssetsTotal', id: 'none' },
],
invalidatesTags: (result, error, { imageDTOs, board_id }) => {
//assume all images are being moved from one board for now
const oldBoardId = imageDTOs[0]?.board_id;
return [
// update the destination board
{ type: 'Board', id: board_id ?? 'none' },
// update new board totals
{ type: 'BoardImagesTotal', id: board_id ?? 'none' },
{ type: 'BoardAssetsTotal', id: board_id ?? 'none' },
// update old board totals
{ type: 'BoardImagesTotal', id: oldBoardId ?? 'none' },
{ type: 'BoardAssetsTotal', id: oldBoardId ?? 'none' },
// update the no_board totals
{ type: 'BoardImagesTotal', id: 'none' },
{ type: 'BoardAssetsTotal', id: 'none' },
];
},
async onQueryStarted(
{ board_id, imageDTOs },
{ board_id: new_board_id, imageDTOs },
{ dispatch, queryFulfilled, getState }
) {
try {
@ -920,7 +935,7 @@ export const imagesApi = api.injectEndpoints({
'getImageDTO',
image_name,
(draft) => {
draft.board_id = board_id;
draft.board_id = new_board_id;
}
)
);
@ -946,7 +961,7 @@ export const imagesApi = api.injectEndpoints({
);
const queryArgs = {
board_id,
board_id: new_board_id,
categories,
};
@ -954,23 +969,24 @@ export const imagesApi = api.injectEndpoints({
queryArgs
)(getState());
const { data: total } = IMAGE_CATEGORIES.includes(
const { data: previousTotal } = IMAGE_CATEGORIES.includes(
imageDTO.image_category
)
? boardsApi.endpoints.getBoardImagesTotal.select(
imageDTO.board_id ?? 'none'
new_board_id ?? 'none'
)(getState())
: boardsApi.endpoints.getBoardAssetsTotal.select(
imageDTO.board_id ?? 'none'
new_board_id ?? 'none'
)(getState());
const isCacheFullyPopulated =
currentCache.data && currentCache.data.ids.length >= (total ?? 0);
currentCache.data &&
currentCache.data.ids.length >= (previousTotal ?? 0);
const isInDateRange = getIsImageInDateRange(
currentCache.data,
imageDTO
);
const isInDateRange =
(previousTotal || 0) >= IMAGE_LIMIT
? getIsImageInDateRange(currentCache.data, imageDTO)
: true;
if (isCacheFullyPopulated || isInDateRange) {
// *upsert* to $cache
@ -981,7 +997,7 @@ export const imagesApi = api.injectEndpoints({
(draft) => {
imagesAdapter.upsertOne(draft, {
...imageDTO,
board_id,
board_id: new_board_id,
});
}
)
@ -1097,10 +1113,10 @@ export const imagesApi = api.injectEndpoints({
const isCacheFullyPopulated =
currentCache.data && currentCache.data.ids.length >= (total ?? 0);
const isInDateRange = getIsImageInDateRange(
currentCache.data,
imageDTO
);
const isInDateRange =
(total || 0) >= IMAGE_LIMIT
? getIsImageInDateRange(currentCache.data, imageDTO)
: true;
if (isCacheFullyPopulated || isInDateRange) {
// *upsert* to $cache
@ -1111,7 +1127,7 @@ export const imagesApi = api.injectEndpoints({
(draft) => {
imagesAdapter.upsertOne(draft, {
...imageDTO,
board_id: undefined,
board_id: 'none',
});
}
)

View File

@ -1443,7 +1443,7 @@ export type components = {
* @description The nodes in this graph
*/
nodes?: {
[key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
[key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
};
/**
* Edges
@ -1486,7 +1486,7 @@ export type components = {
* @description The results of node executions
*/
results: {
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
};
/**
* Errors
@ -1904,6 +1904,40 @@ export type components = {
*/
image_name: string;
};
/**
* ImageHueAdjustmentInvocation
* @description Adjusts the Hue of an image.
*/
ImageHueAdjustmentInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default img_hue_adjust
* @enum {string}
*/
type?: "img_hue_adjust";
/**
* Image
* @description The image to adjust
*/
image?: components["schemas"]["ImageField"];
/**
* Hue
* @description The degrees by which to rotate the hue, 0-360
* @default 0
*/
hue?: number;
};
/**
* ImageInverseLerpInvocation
* @description Inverse linear interpolation of all pixels of an image
@ -1984,6 +2018,40 @@ export type components = {
*/
max?: number;
};
/**
* ImageLuminosityAdjustmentInvocation
* @description Adjusts the Luminosity (Value) of an image.
*/
ImageLuminosityAdjustmentInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default img_luminosity_adjust
* @enum {string}
*/
type?: "img_luminosity_adjust";
/**
* Image
* @description The image to adjust
*/
image?: components["schemas"]["ImageField"];
/**
* Luminosity
* @description The factor by which to adjust the luminosity (value)
* @default 1
*/
luminosity?: number;
};
/**
* ImageMetadata
* @description An image's generation metadata
@ -2239,6 +2307,40 @@ export type components = {
*/
resample_mode?: "nearest" | "box" | "bilinear" | "hamming" | "bicubic" | "lanczos";
};
/**
* ImageSaturationAdjustmentInvocation
* @description Adjusts the Saturation of an image.
*/
ImageSaturationAdjustmentInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default img_saturation_adjust
* @enum {string}
*/
type?: "img_saturation_adjust";
/**
* Image
* @description The image to adjust
*/
image?: components["schemas"]["ImageField"];
/**
* Saturation
* @description The factor by which to adjust the saturation
* @default 1
*/
saturation?: number;
};
/**
* ImageScaleInvocation
* @description Scales an image by a factor
@ -4912,6 +5014,82 @@ export type components = {
*/
denoising_end?: number;
};
/**
* SDXLLoraLoaderInvocation
* @description Apply selected lora to unet and text_encoder.
*/
SDXLLoraLoaderInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_lora_loader
* @enum {string}
*/
type?: "sdxl_lora_loader";
/**
* Lora
* @description Lora model name
*/
lora?: components["schemas"]["LoRAModelField"];
/**
* Weight
* @description With what weight to apply lora
* @default 0.75
*/
weight?: number;
/**
* Unet
* @description UNet model for applying lora
*/
unet?: components["schemas"]["UNetField"];
/**
* Clip
* @description Clip model for applying lora
*/
clip?: components["schemas"]["ClipField"];
/**
* Clip2
* @description Clip2 model for applying lora
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLLoraLoaderOutput
* @description Model loader output
*/
SDXLLoraLoaderOutput: {
/**
* Type
* @default sdxl_lora_loader_output
* @enum {string}
*/
type?: "sdxl_lora_loader_output";
/**
* Unet
* @description UNet submodel
*/
unet?: components["schemas"]["UNetField"];
/**
* Clip
* @description Tokenizer and text_encoder submodels
*/
clip?: components["schemas"]["ClipField"];
/**
* Clip2
* @description Tokenizer2 and text_encoder2 submodels
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLModelLoaderInvocation
* @description Loads an sdxl base model, outputting its submodels.
@ -5961,6 +6139,24 @@ export type components = {
*/
image?: components["schemas"]["ImageField"];
};
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
@ -5973,24 +6169,6 @@ export type components = {
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;
@ -6101,7 +6279,7 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
};
};
responses: {
@ -6138,7 +6316,7 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
};
};
responses: {

View File

@ -166,6 +166,9 @@ export type OnnxModelLoaderInvocation = TypeReq<
export type LoraLoaderInvocation = TypeReq<
components['schemas']['LoraLoaderInvocation']
>;
export type SDXLLoraLoaderInvocation = TypeReq<
components['schemas']['SDXLLoraLoaderInvocation']
>;
export type MetadataAccumulatorInvocation = TypeReq<
components['schemas']['MetadataAccumulatorInvocation']
>;

View File

@ -1 +1 @@
__version__ = "3.0.2a1"
__version__ = "3.0.2rc1"

View File

@ -77,7 +77,7 @@ dependencies = [
"realesrgan",
"requests~=2.28.2",
"rich~=13.3",
"safetensors~=0.3.0",
"safetensors==0.3.1",
"scikit-image~=0.21.0",
"send2trash",
"test-tube~=0.7.5",
@ -100,7 +100,7 @@ dependencies = [
"dev" = [
"pudb",
]
"test" = ["pytest>6.0.0", "pytest-cov", "black"]
"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"]
"xformers" = [
"xformers~=0.0.19; sys_platform!='darwin'",
"triton; sys_platform=='linux'",
@ -139,6 +139,7 @@ dependencies = [
"invokeai-metadata" = "invokeai.frontend.CLI.sd_metadata:print_metadata"
"invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli"
"invokeai-node-web" = "invokeai.app.api_app:invoke_api"
"invokeai-import-images" = "invokeai.frontend.install.import_images:main"
[project.urls]
"Homepage" = "https://invoke-ai.github.io/InvokeAI/"

View File

@ -0,0 +1,38 @@
from pathlib import Path
import pytest
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType
BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
@pytest.fixture
def model_manager(datadir) -> ModelManager:
InvokeAIAppConfig.get_config(root=datadir)
return ModelManager(datadir / "configs" / "relative_sub.models.yaml")
def test_get_model_names(model_manager: ModelManager):
names = model_manager.model_names()
assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME]
def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2])
top_model_path, is_override = model_manager._get_model_path(model_config)
expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0"
assert top_model_path == expected_model_path
assert not is_override
def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(
VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2]
)
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
assert vae_model_path == expected_vae_path
assert is_override

View File

@ -0,0 +1,15 @@
__metadata__:
version: 3.0.0
sdxl/main/SDXL base:
path: sdxl/main/SDXL base 1_0
description: SDXL base v1.0
variant: normal
format: diffusers
sdxl/main/SDXL with VAE:
path: sdxl/main/SDXL base 1_0
description: SDXL with customized VAE
vae: sdxl/vae/sdxl-vae-fp16-fix/
variant: normal
format: diffusers