Compare commits

..

5 Commits

104 changed files with 1306 additions and 2422 deletions

View File

@ -10,11 +10,10 @@ help:
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "test Run the unit tests."
@echo "frontend-install Install the pnpm modules needed for the front end"
@echo "test" Run the unit tests.
@echo "frontend-install" Install the pnpm modules needed for the front end
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@ -54,9 +53,6 @@ frontend-build:
frontend-dev:
cd invokeai/frontend/web && pnpm dev
frontend-typegen:
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
# Installer zip file
installer-zip:
cd installer && ./create_installer.sh

View File

@ -25,7 +25,6 @@ from ..services.invocation_cache.invocation_cache_memory import MemoryInvocation
from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_images.model_images_default import ModelImageFileStorageDisk
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
@ -72,8 +71,6 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images")
model_images_folder = config.models_path
db = init_db(config=config, logger=logger, image_files=image_files)
configuration = config
@ -95,7 +92,6 @@ class ApiDependencies:
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db),
@ -122,7 +118,6 @@ class ApiDependencies:
images=images,
invocation_cache=invocation_cache,
logger=logger,
model_images=model_images_service,
model_manager=model_manager,
download_queue=download_queue_service,
names=names,

View File

@ -1,16 +1,12 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
import io
import pathlib
import shutil
import traceback
from typing import Any, Dict, List, Optional
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
@ -35,9 +31,6 @@ from ..dependencies import ApiDependencies
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
# images are immutable; set a high max-age
IMAGE_MAX_AGE = 31536000
class ModelsList(BaseModel):
"""Return list of configs."""
@ -112,9 +105,6 @@ async def list_model_records(
found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
)
for model in found_models:
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
model.cover_image = cover_image
return ModelsList(models=found_models)
@ -158,8 +148,6 @@ async def get_model_record(
record_store = ApiDependencies.invoker.services.model_manager.store
try:
config: AnyModelConfig = record_store.get_model(key)
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
config.cover_image = cover_image
return config
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@ -278,75 +266,6 @@ async def update_model_record(
return model_response
@model_manager_router.get(
"/i/{key}/image",
operation_id="get_model_image",
responses={
200: {
"description": "The model image was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The model image could not be found"},
},
status_code=200,
)
async def get_model_image(
key: str = Path(description="The name of model image file to get"),
) -> FileResponse:
"""Gets an image file that previews the model"""
try:
path = ApiDependencies.invoker.services.model_images.get_path(key)
response = FileResponse(
path,
media_type="image/png",
filename=key + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)
@model_manager_router.patch(
"/i/{key}/image",
operation_id="update_model_image",
responses={
200: {
"description": "The model image was updated successfully",
},
400: {"description": "Bad request"},
},
status_code=200,
)
async def update_model_image(
key: Annotated[str, Path(description="Unique key of model")],
image: UploadFile,
) -> None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
logger = ApiDependencies.invoker.services.logger
model_images = ApiDependencies.invoker.services.model_images
try:
model_images.save(pil_image, key)
logger.info(f"Updated image for model: {key}")
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return
@model_manager_router.delete(
"/i/{key}",
operation_id="delete_model",
@ -377,29 +296,6 @@ async def delete_model(
raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.delete(
"/i/{key}/image",
operation_id="delete_model_image",
responses={
204: {"description": "Model image deleted successfully"},
404: {"description": "Model image not found"},
},
status_code=204,
)
async def delete_model_image(
key: str = Path(description="Unique key of model image to remove from model_images directory."),
) -> None:
logger = ApiDependencies.invoker.services.logger
model_images = ApiDependencies.invoker.services.model_images
try:
model_images.delete(key)
logger.info(f"Deleted model image: {key}")
return
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
# @model_manager_router.post(
# "/i/",
# operation_id="add_model_record",
@ -643,7 +539,7 @@ async def convert_model(
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
# loading the model will convert it into a cached diffusers file
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)

View File

@ -2,9 +2,12 @@
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
import sys
from contextlib import asynccontextmanager
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.version.invokeai_version import __version__
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from .services.config import InvokeAIAppConfig
app_config = InvokeAIAppConfig.get_config()
@ -17,7 +20,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
import asyncio
import mimetypes
import socket
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any
@ -38,7 +40,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
@ -58,7 +59,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)

View File

@ -20,7 +20,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
from invokeai.backend.util.devices import torch_dtype
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .model import CLIPField
from .model import ClipField
# unconditioned: Optional[torch.Tensor]
@ -46,7 +46,7 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.compel_prompt,
ui_component=UIComponent.Textarea,
)
clip: CLIPField = InputField(
clip: ClipField = InputField(
title="CLIP",
description=FieldDescriptions.clip,
input=Input.Connection,
@ -54,16 +54,16 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer)
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
@ -127,16 +127,16 @@ class SDXLPromptInvocationBase:
def run_clip_compel(
self,
context: InvocationContext,
clip_field: CLIPField,
clip_field: ClipField,
prompt: str,
get_pooled: bool,
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
tokenizer_info = context.models.load(clip_field.tokenizer)
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
@ -163,7 +163,7 @@ class SDXLPromptInvocationBase:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)
@ -253,8 +253,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
crop_left: int = InputField(default=0, description="")
target_width: int = InputField(default=1024, description="")
target_height: int = InputField(default=1024, description="")
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -340,7 +340,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
crop_top: int = InputField(default=0, description="")
crop_left: int = InputField(default=0, description="")
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -370,10 +370,10 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
@invocation_output("clip_skip_output")
class CLIPSkipInvocationOutput(BaseInvocationOutput):
"""CLIP skip node output"""
class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output"""
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation(
@ -383,15 +383,15 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput):
category="conditioning",
version="1.0.0",
)
class CLIPSkipInvocation(BaseInvocation):
class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
def invoke(self, context: InvocationContext) -> CLIPSkipInvocationOutput:
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers
return CLIPSkipInvocationOutput(
return ClipSkipInvocationOutput(
clip=self.clip,
)

View File

@ -34,7 +34,6 @@ from invokeai.app.invocations.fields import (
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
@ -52,9 +51,15 @@ CONTROLNET_RESIZE_VALUES = Literal[
]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
key: str = Field(description="Model config record key for the ControlNet model")
class ControlField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelField = Field(description="The ControlNet model to use")
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@ -90,7 +95,7 @@ class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)

View File

@ -228,7 +228,7 @@ class ConditioningField(BaseModel):
# endregion
class MetadataField(RootModel[dict[str, Any]]):
class MetadataField(RootModel):
"""
Pydantic model for metadata with custom root of type dict[str, Any].
Metadata is stored without a strict schema.

View File

@ -11,17 +11,25 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType
from invokeai.backend.model_manager.config import BaseModelType, ModelType
# LS: Consider moving these two classes into model.py
class IPAdapterModelField(BaseModel):
key: str = Field(description="Key to the IP-Adapter model")
class CLIPVisionModelField(BaseModel):
key: str = Field(description="Key to the CLIP Vision image encoder model")
class IPAdapterField(BaseModel):
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.")
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
@ -54,7 +62,7 @@ class IPAdapterInvocation(BaseInvocation):
# Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelField = InputField(
ip_adapter_model: IPAdapterModelField = InputField(
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
)
@ -82,18 +90,18 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, IPAdapterConfig)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
assert len(image_encoder_models) == 1
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=ModelField(key=image_encoder_models[0].key),
image_encoder_model=image_encoder_model,
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,

View File

@ -26,7 +26,6 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
from PIL import Image, ImageFilter
from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
@ -76,7 +75,7 @@ from .baseinvocation import (
invocation_output,
)
from .controlnet_image_processors import ControlField
from .model import ModelField, UNetField, VAEField
from .model import ModelInfo, UNetField, VaeField
if choose_torch_device() == torch.device("mps"):
from torch import mps
@ -119,7 +118,7 @@ class SchedulerInvocation(BaseInvocation):
class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
@ -154,7 +153,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
if image_tensor is not None:
vae_info = context.models.load(self.vae.vae)
vae_info = context.models.load(**self.vae.vae.model_dump())
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
@ -245,12 +244,12 @@ class CreateGradientMaskInvocation(BaseInvocation):
def get_scheduler(
context: InvocationContext,
scheduler_info: ModelField,
scheduler_info: ModelInfo,
scheduler_name: str,
seed: int,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@ -462,7 +461,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# and if weight is None, populate with default 1.0?
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
# control_models.append(control_model)
control_image_field = control_info.image
@ -524,10 +523,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(single_ip_adapter.ip_adapter_model)
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
@ -538,7 +538,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
single_ipa_images, image_encoder_model
@ -578,8 +577,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
image = context.images.get_pil(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@ -732,13 +731,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
yield (lora_info.model, lora.weight)
del lora_info
return
unet_info = context.models.load(self.unet.unet)
unet_info = context.models.load(**self.unet.unet.model_dump())
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
@ -832,7 +830,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
vae: VaeField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@ -843,8 +841,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
vae_info = context.models.load(**self.vae.vae.model_dump())
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module)
latents = latents.to(vae.device)
@ -1010,7 +1008,7 @@ class ImageToLatentsInvocation(BaseInvocation):
image: ImageField = InputField(
description="The image to encode",
)
vae: VAEField = InputField(
vae: VaeField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@ -1066,7 +1064,7 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
vae_info = context.models.load(**self.vae.vae.model_dump())
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:

View File

@ -8,10 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import (
CONTROLNET_MODE_VALUES,
CONTROLNET_RESIZE_VALUES,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@ -20,8 +17,10 @@ from invokeai.app.invocations.fields import (
OutputField,
UIType,
)
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from ...version import __version__
@ -31,20 +30,10 @@ class MetadataItemField(BaseModel):
value: Any = Field(description=FieldDescriptions.metadata_item_value)
class ModelMetadataField(BaseModel):
"""Model Metadata Field"""
key: str
hash: str
name: str
base: BaseModelType
type: ModelType
class LoRAMetadataField(BaseModel):
"""LoRA Metadata Field"""
model: ModelMetadataField = Field(description=FieldDescriptions.lora_model)
model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
weight: float = Field(description=FieldDescriptions.lora_weight)
@ -52,7 +41,7 @@ class IPAdapterMetadataField(BaseModel):
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelMetadataField = Field(
ip_adapter_model: IPAdapterModelField = Field(
description="The IP-Adapter model.",
)
weight: Union[float, list[float]] = Field(
@ -62,33 +51,6 @@ class IPAdapterMetadataField(BaseModel):
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
class T2IAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
class ControlNetMetadataField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelMetadataField = Field(description="The ControlNet model to use")
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@invocation_output("metadata_item_output")
class MetadataItemOutput(BaseInvocationOutput):
"""Metadata Item Output"""
@ -178,14 +140,14 @@ class CoreMetadataInvocation(BaseInvocation):
default=None,
description="The number of skipped CLIP layers",
)
model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference")
controlnets: Optional[list[ControlNetMetadataField]] = InputField(
model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference")
controlnets: Optional[list[ControlField]] = InputField(
default=None, description="The ControlNets used for inference"
)
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
default=None, description="The IP Adapters used for inference"
)
t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField(
t2iAdapters: Optional[list[T2IAdapterField]] = InputField(
default=None, description="The IP Adapters used for inference"
)
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
@ -197,7 +159,7 @@ class CoreMetadataInvocation(BaseInvocation):
default=None,
description="The name of the initial image",
)
vae: Optional[ModelMetadataField] = InputField(
vae: Optional[VAEModelField] = InputField(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
@ -228,7 +190,7 @@ class CoreMetadataInvocation(BaseInvocation):
)
# SDXL Refiner
refiner_model: Optional[ModelMetadataField] = InputField(
refiner_model: Optional[MainModelField] = InputField(
default=None,
description="The SDXL Refiner model used",
)
@ -260,9 +222,10 @@ class CoreMetadataInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> MetadataOutput:
"""Collects and outputs a CoreMetadata object"""
as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
as_dict["app_version"] = __version__
return MetadataOutput(metadata=MetadataField.model_validate(as_dict))
return MetadataOutput(
metadata=MetadataField.model_validate(
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
)
)
model_config = ConfigDict(extra="allow")

View File

@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import SubModelType
from ...backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -16,33 +16,33 @@ from .baseinvocation import (
)
class ModelField(BaseModel):
key: str = Field(description="Key of the model")
submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
class ModelInfo(BaseModel):
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
class LoRAField(BaseModel):
lora: ModelField = Field(description="Info to load lora model")
weight: float = Field(description="Weight to apply to lora model")
class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model")
class UNetField(BaseModel):
unet: ModelField = Field(description="Info to load unet submodel")
scheduler: ModelField = Field(description="Info to load scheduler submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class CLIPField(BaseModel):
tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class VAEField(BaseModel):
vae: ModelField = Field(description="Info to load vae submodel")
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@ -57,14 +57,14 @@ class UNetOutput(BaseInvocationOutput):
class VAEOutput(BaseInvocationOutput):
"""Base class for invocations that output a VAE field"""
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation_output("clip_output")
class CLIPOutput(BaseInvocationOutput):
"""Base class for invocations that output a CLIP field"""
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
@invocation_output("model_loader_output")
@ -74,6 +74,18 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
pass
class MainModelField(BaseModel):
"""Main model field"""
key: str = Field(description="Model key")
class LoRAModelField(BaseModel):
"""LoRA model field"""
key: str = Field(description="LoRA model key")
@invocation(
"main_model_loader",
title="Main Model",
@ -84,40 +96,62 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
# TODO: not found exceptions
if not context.models.exists(self.model.key):
raise Exception(f"Unknown model {self.model.key}")
key = self.model.key
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
# TODO: not found exceptions
if not context.models.exists(key):
raise Exception(f"Unknown model {key}")
return ModelLoaderOutput(
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
unet=UNetField(
unet=ModelInfo(
key=key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=key,
submodel_type=SubModelType.VAE,
),
),
)
@invocation_output("lora_loader_output")
class LoRALoaderOutput(BaseInvocationOutput):
class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
class LoRALoaderInvocation(BaseInvocation):
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None,
@ -125,41 +159,46 @@ class LoRALoaderInvocation(BaseInvocation):
input=Input.Connection,
title="UNet",
)
clip: Optional[CLIPField] = InputField(
clip: Optional[ClipField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
)
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise Exception(f"Unkown lora: {lora_key}!")
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'LoRA "{lora_key}" already applied to unet')
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
output = LoRALoaderOutput()
output = LoraLoaderOutput()
if self.unet is not None:
output.unet = self.unet.model_copy(deep=True)
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoRAField(
lora=self.lora,
LoraInfo(
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = self.clip.model_copy(deep=True)
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoRAField(
lora=self.lora,
LoraInfo(
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
@ -168,12 +207,12 @@ class LoRALoaderInvocation(BaseInvocation):
@invocation_output("sdxl_lora_loader_output")
class SDXLLoRALoaderOutput(BaseInvocationOutput):
class SDXLLoraLoaderOutput(BaseInvocationOutput):
"""SDXL LoRA Loader Output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
clip2: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
@invocation(
@ -183,10 +222,10 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
category="model",
version="1.0.1",
)
class SDXLLoRALoaderInvocation(BaseInvocation):
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None,
@ -194,59 +233,65 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
input=Input.Connection,
title="UNet",
)
clip: Optional[CLIPField] = InputField(
clip: Optional[ClipField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP 1",
)
clip2: Optional[CLIPField] = InputField(
clip2: Optional[ClipField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP 2",
)
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!")
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'LoRA "{lora_key}" already applied to unet')
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip2')
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip2')
output = SDXLLoRALoaderOutput()
output = SDXLLoraLoaderOutput()
if self.unet is not None:
output.unet = self.unet.model_copy(deep=True)
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoRAField(
lora=self.lora,
LoraInfo(
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = self.clip.model_copy(deep=True)
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoRAField(
lora=self.lora,
LoraInfo(
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
if self.clip2 is not None:
output.clip2 = self.clip2.model_copy(deep=True)
output.clip2 = copy.deepcopy(self.clip2)
output.clip2.loras.append(
LoRAField(
lora=self.lora,
LoraInfo(
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
@ -254,11 +299,17 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
return output
class VAEModelField(BaseModel):
"""Vae model field"""
key: str = Field(description="Model's key")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
class VAELoaderInvocation(BaseInvocation):
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: ModelField = InputField(
vae_model: VAEModelField = InputField(
description=FieldDescriptions.vae_model,
input=Input.Direct,
title="VAE",
@ -270,7 +321,7 @@ class VAELoaderInvocation(BaseInvocation):
if not context.models.exists(key):
raise Exception(f"Unkown vae: {key}!")
return VAEOutput(vae=VAEField(vae=self.vae_model))
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
@invocation_output("seamless_output")
@ -278,7 +329,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
"""Modified Seamless Model output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
vae: Optional[VAEField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
@invocation(
@ -297,7 +348,7 @@ class SeamlessModeInvocation(BaseInvocation):
input=Input.Connection,
title="UNet",
)
vae: Optional[VAEField] = InputField(
vae: Optional[VaeField] = InputField(
default=None,
description=FieldDescriptions.vae_model,
input=Input.Connection,

View File

@ -8,7 +8,7 @@ from .baseinvocation import (
invocation,
invocation_output,
)
from .model import CLIPField, ModelField, UNetField, VAEField
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
@invocation_output("sdxl_model_loader_output")
@ -16,9 +16,9 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL base model loader output"""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation_output("sdxl_refiner_model_loader_output")
@ -26,15 +26,15 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output"""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
model: ModelField = InputField(
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
)
# TODO: precision?
@ -46,19 +46,48 @@ class SDXLModelLoaderInvocation(BaseInvocation):
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SDXLModelLoaderOutput(
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
unet=UNetField(
unet=ModelInfo(
key=model_key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.VAE,
),
),
)
@ -72,8 +101,10 @@ class SDXLModelLoaderInvocation(BaseInvocation):
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
model: ModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model,
input=Input.Direct,
ui_type=UIType.SDXLRefinerModel,
)
# TODO: precision?
@ -84,14 +115,34 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SDXLRefinerModelLoaderOutput(
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
unet=UNetField(
unet=ModelInfo(
key=model_key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.VAE,
),
),
)

View File

@ -10,14 +10,17 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
class T2IAdapterModelField(BaseModel):
key: str = Field(description="Model record key for the T2I-Adapter model")
class T2IAdapterField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.")
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
@ -52,7 +55,7 @@ class T2IAdapterInvocation(BaseInvocation):
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
t2i_adapter_model: ModelField = InputField(
t2i_adapter_model: T2IAdapterModelField = InputField(
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
input=Input.Direct,

View File

@ -41,9 +41,8 @@ class InvocationCacheBase(ABC):
"""Clears the cache"""
pass
@staticmethod
@abstractmethod
def create_key(invocation: BaseInvocation) -> int:
def create_key(self, invocation: BaseInvocation) -> int:
"""Gets the key for the invocation's cache item"""
pass

View File

@ -61,7 +61,9 @@ class MemoryInvocationCache(InvocationCacheBase):
self._delete_oldest_access(number_to_delete)
self._cache[key] = CachedItem(
invocation_output,
invocation_output.model_dump_json(warnings=False, exclude_defaults=True, exclude_unset=True),
invocation_output.model_dump_json(
warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
),
)
def _delete_oldest_access(self, number_to_delete: int) -> None:
@ -79,7 +81,7 @@ class MemoryInvocationCache(InvocationCacheBase):
with self._lock:
return self._delete(key)
def clear(self) -> None:
def clear(self, *args, **kwargs) -> None:
with self._lock:
if self._max_cache_size == 0:
return

View File

@ -25,7 +25,6 @@ if TYPE_CHECKING:
from .images.images_base import ImageServiceABC
from .invocation_cache.invocation_cache_base import InvocationCacheBase
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .model_images.model_images_base import ModelImageFileStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
@ -50,7 +49,6 @@ class InvocationServices:
image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase",
logger: "Logger",
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase",
@ -74,7 +72,6 @@ class InvocationServices:
self.image_files = image_files
self.image_records = image_records
self.logger = logger
self.model_images = model_images
self.model_manager = model_manager
self.download_queue = download_queue
self.performance_statistics = performance_statistics

View File

@ -1,33 +0,0 @@
from abc import ABC, abstractmethod
from pathlib import Path
from PIL.Image import Image as PILImageType
class ModelImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, model_key: str) -> PILImageType:
"""Retrieves a model image as PIL Image."""
pass
@abstractmethod
def get_path(self, model_key: str) -> Path:
"""Gets the internal path to a model image."""
pass
@abstractmethod
def get_url(self, model_key: str) -> str | None:
"""Gets the URL to fetch a model image."""
pass
@abstractmethod
def save(self, image: PILImageType, model_key: str) -> None:
"""Saves a model image."""
pass
@abstractmethod
def delete(self, model_key: str) -> None:
"""Deletes a model image."""
pass

View File

@ -1,20 +0,0 @@
# TODO: Should these excpetions subclass existing python exceptions?
class ModelImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Model image file not found"):
super().__init__(message)
class ModelImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Model image file not saved"):
super().__init__(message)
class ModelImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Model image file not deleted"):
super().__init__(message)

View File

@ -1,79 +0,0 @@
from pathlib import Path
from PIL import Image
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.services.invoker import Invoker
from invokeai.app.util.thumbnails import make_thumbnail
from .model_images_base import ModelImageFileStorageBase
from .model_images_common import (
ModelImageFileDeleteException,
ModelImageFileNotFoundException,
ModelImageFileSaveException,
)
class ModelImageFileStorageDisk(ModelImageFileStorageBase):
"""Stores images on disk"""
def __init__(self, model_images_folder: Path):
self._model_images_folder = model_images_folder
self._validate_storage_folders()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def get(self, model_key: str) -> PILImageType:
try:
path = self.get_path(model_key)
if not self._validate_path(path):
raise ModelImageFileNotFoundException
return Image.open(path)
except FileNotFoundError as e:
raise ModelImageFileNotFoundException from e
def save(self, image: PILImageType, model_key: str) -> None:
try:
self._validate_storage_folders()
image_path = self._model_images_folder / (model_key + ".webp")
thumbnail = make_thumbnail(image, 256)
thumbnail.save(image_path, format="webp")
except Exception as e:
raise ModelImageFileSaveException from e
def get_path(self, model_key: str) -> Path:
path = self._model_images_folder / (model_key + ".webp")
return path
def get_url(self, model_key: str) -> str | None:
path = self.get_path(model_key)
if not self._validate_path(path):
return
return self._invoker.services.urls.get_model_image_url(model_key)
def delete(self, model_key: str) -> None:
try:
path = self.get_path(model_key)
if not self._validate_path(path):
raise ModelImageFileNotFoundException
send2trash(path)
except Exception as e:
raise ModelImageFileDeleteException from e
def _validate_path(self, path: Path) -> bool:
"""Validates the path given for an image."""
return path.exists()
def _validate_storage_folders(self) -> None:
"""Checks if the required folders exist and create them if they don't"""
self._model_images_folder.mkdir(parents=True, exist_ok=True)

View File

@ -4,7 +4,6 @@ import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
@ -281,18 +280,12 @@ class ModelInstallService(ModelInstallServiceBase):
self._scan_models_directory()
if autoimport := self._app_config.autoimport_dir:
self._logger.info("Scanning autoimport directory for new models")
installed: List[str] = []
# Use ThreadPoolExecutor to scan dirs in parallel
with ThreadPoolExecutor() as executor:
future_models = [executor.submit(self.scan_directory, self._app_config.root_path / autoimport / cur_model_type.value) for cur_model_type in ModelType]
[installed.extend(models.result()) for models in as_completed(future_models)]
installed = self.scan_directory(self._app_config.root_path / autoimport)
self._logger.info(f"{len(installed)} new models registered")
self._logger.info("Model installer (re)initialized")
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
if len([entry for entry in os.scandir(scan_dir) if not entry.name.startswith(".")]) == 0:
return []
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback, config=self._app_config)
self._models_installed.clear()
@ -455,10 +448,10 @@ class ModelInstallService(ModelInstallServiceBase):
self.unregister(key)
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
# Use ThreadPoolExecutor to scan dirs in parallel
with ThreadPoolExecutor() as executor:
future_models = [executor.submit(self.scan_directory, Path(cur_base_model.value, cur_model_type.value)) for cur_base_model in BaseModelType for cur_model_type in ModelType]
[installed.update(models.result()) for models in as_completed(future_models)]
for cur_base_model in BaseModelType:
for cur_model_type in ModelType:
models_dir = Path(cur_base_model.value, cur_model_type.value)
installed.update(self.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
def _sync_model_path(self, key: str) -> AnyModelConfig:

View File

@ -1,11 +1,15 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from abc import ABC, abstractmethod
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
@ -66,3 +70,32 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def stop(self, invoker: Invoker) -> None:
pass
@abstractmethod
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@abstractmethod
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass

View File

@ -1,10 +1,14 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
@ -14,7 +18,7 @@ from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase
from ..model_records import ModelRecordServiceBase, UnknownModelException
from .model_manager_base import ModelManagerServiceBase
@ -60,6 +64,56 @@ class ModelManagerService(ModelManagerServiceBase):
if hasattr(service, "stop"):
service.stop(invoker)
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
return self.load.load_model(model_config, submodel_type, context_data)
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
config = self.store.get_model(key)
return self.load.load_model(config, submodel_type, context_data)
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self.store.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load.load_model(configs[0], submodel, context_data)
@classmethod
def build_model_manager(
cls,

View File

@ -79,7 +79,6 @@ class ModelRecordChanges(BaseModelExcludeNull):
description="The prediction type of the model.", default=None
)
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
config_path: Optional[str] = Field(description="Path to config file for model", default=None)
class ModelRecordServiceBase(ABC):
@ -130,17 +129,6 @@ class ModelRecordServiceBase(ABC):
"""
pass
@abstractmethod
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
"""
Retrieve the configuration for the indicated model.
:param hash: Hash of model config to be fetched.
Exceptions: UnknownModelException
"""
pass
@abstractmethod
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default

View File

@ -203,21 +203,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.

View File

@ -1,6 +1,35 @@
from abc import ABC, abstractmethod
from threading import Event
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
class SessionRunnerBase(ABC):
"""
Base class for session runner.
"""
@abstractmethod
def start(self, services: InvocationServices, cancel_event: Event) -> None:
"""Starts the session runner"""
pass
@abstractmethod
def run(self, queue_item: SessionQueueItem) -> None:
"""Runs the session"""
pass
@abstractmethod
def complete(self, queue_item: SessionQueueItem) -> None:
"""Completes the session"""
pass
@abstractmethod
def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
"""Runs an already prepared node on the session"""
pass
class SessionProcessorBase(ABC):

View File

@ -2,13 +2,14 @@ import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional
from typing import Callable, Optional, Union
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
@ -16,15 +17,164 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat
from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase
from .session_processor_base import SessionProcessorBase, SessionRunnerBase
from .session_processor_common import SessionProcessorStatus
class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations"""
def __init__(
self,
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
):
self.on_before_run_node = on_before_run_node
self.on_after_run_node = on_after_run_node
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
"""Start the session runner"""
self.services = services
self.cancel_event = cancel_event
def run(self, queue_item: SessionQueueItem):
"""Run the graph"""
if not queue_item.session:
raise ValueError("Queue item has no session")
# Loop over invocations until the session is complete or canceled
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
# Prepare the next node
invocation = queue_item.session.next()
if invocation is None:
# If there are no more invocations, complete the graph
break
# Build invocation context (the node-facing API
self.run_node(invocation.id, queue_item)
self.complete(queue_item)
def complete(self, queue_item: SessionQueueItem):
"""Complete the graph"""
self.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
)
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run before a node is executed"""
# Send starting event
self.services.events.emit_invocation_started(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session_id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
)
if self.on_before_run_node is not None:
self.on_before_run_node(invocation, queue_item)
def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run after a node is executed"""
if self.on_after_run_node is not None:
self.on_after_run_node(invocation, queue_item)
def run_node(self, node_id: str, queue_item: SessionQueueItem):
"""Run a single node in the graph"""
# If this error raises a NodeNotFoundError that's handled by the processor
invocation = queue_item.session.execution_graph.get_node(node_id)
try:
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
context = build_invocation_context(
data=data,
services=self.services,
cancel_event=self.cancel_event,
)
# Invoke the node
outputs = invocation.invoke_internal(context=context, services=self.services)
# Save outputs and history
queue_item.session.complete(invocation.id, outputs)
self._on_after_run_node(invocation, queue_item)
# Send complete event on successful runs
self.services.events.emit_invocation_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=data.source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
queue_item.session.set_node_error(invocation.id, error)
self.services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
)
self.services.logger.error(error)
# Send error event
self.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=e.__class__.__name__,
error=error,
)
class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
"""Processes sessions from the session queue"""
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
def start(
self,
invoker: Invoker,
thread_limit: int = 1,
polling_interval: int = 1,
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
) -> None:
self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None
self.on_before_run_session = on_before_run_session
self.on_after_run_session = on_after_run_session
self._resume_event = ThreadEvent()
self._stop_event = ThreadEvent()
@ -59,6 +209,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
"cancel_event": self._cancel_event,
},
)
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
self._thread.start()
def stop(self, *args, **kwargs) -> None:
@ -117,131 +268,34 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# If we have a on_before_run_session callback, call it
if self.on_before_run_session is not None:
self.on_before_run_session(self._queue_item)
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id)
# Prepare invocations and take the first
self._invocation = self._queue_item.session.next()
# Run the graph
self.session_runner.run(queue_item=self._queue_item)
# Loop over invocations until the session is complete or canceled
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
with self._invoker.services.performance_statistics.collect_stats(
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Invoke the node
outputs = self._invocation.invoke_internal(
context=context, services=self._invoker.services
)
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# The session is complete if the all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Set the invocation to None to prepare for the next session
self._invocation = None
else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
# If we have a on_after_run_session callback, call it
if self.on_after_run_session is not None:
self.on_after_run_session(self._queue_item)
# The session is complete, immediately poll for next session
self._queue_item = None
@ -275,3 +329,4 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.clear()
self._queue_item = None
self._thread_semaphore.release()
self._invoker.services.logger.debug("Session processor stopped")

View File

@ -1,7 +1,7 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Optional
from PIL.Image import Image
from torch import Tensor
@ -13,16 +13,15 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.model import ModelField
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
"""
@ -300,25 +299,22 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
def exists(self, identifier: Union[str, "ModelField"]) -> bool:
def exists(self, key: str) -> bool:
"""Checks if a model exists.
Args:
identifier: The key or ModelField representing the model.
key: The key of the model.
Returns:
True if the model exists, False if not.
"""
if isinstance(identifier, str):
return self._services.model_manager.store.exists(identifier)
return self._services.model_manager.store.exists(key)
return self._services.model_manager.store.exists(identifier.key)
def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Loads a model.
Args:
identifier: The key or ModelField representing the model.
key: The key of the model.
submodel_type: The submodel of the model to get.
Returns:
@ -328,13 +324,9 @@ class ModelsInterface(InvocationContextInterface):
# The model manager emits events as it loads the model. It needs the context data to build
# the event payloads.
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
else:
_submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
return self._services.model_manager.load_model_by_key(
key=key, submodel_type=submodel_type, context_data=self._data
)
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@ -351,29 +343,35 @@ class ModelsInterface(InvocationContextInterface):
Returns:
An object representing the loaded model.
"""
return self._services.model_manager.load_model_by_attr(
model_name=name,
base_model=base,
model_type=type,
submodel=submodel_type,
context_data=self._data,
)
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
if len(configs) == 0:
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig:
def get_config(self, key: str) -> AnyModelConfig:
"""Gets a model's config.
Args:
identifier: The key or ModelField representing the model.
key: The key of the model.
Returns:
The model's config.
"""
if isinstance(identifier, str):
return self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.store.get_model(key=key)
return self._services.model_manager.store.get_model(identifier.key)
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""Gets a model's metadata, if it has any.
Args:
key: The key of the model.
Returns:
The model's metadata, if it has any.
"""
return self._services.model_manager.store.get_metadata(key=key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""Searches for models by path.

View File

@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
See :class:`Migration` for an example.
"""
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
def __call__(self, cursor: sqlite3.Cursor) -> None:
...
class MigrationError(RuntimeError):

View File

@ -8,8 +8,3 @@ class UrlServiceBase(ABC):
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the URL for an image or thumbnail."""
pass
@abstractmethod
def get_model_image_url(self, model_key: str) -> str:
"""Gets the URL for a model image"""
pass

View File

@ -4,9 +4,8 @@ from .urls_base import UrlServiceBase
class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1", base_url_v2: str = "api/v2"):
def __init__(self, base_url: str = "api/v1"):
self._base_url = base_url
self._base_url_v2 = base_url_v2
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
image_basename = os.path.basename(image_name)
@ -16,6 +15,3 @@ class LocalUrlService(UrlServiceBase):
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
return f"{self._base_url}/images/i/{image_basename}/full"
def get_model_image_url(self, model_key: str) -> str:
return f"{self._base_url_v2}/models/i/{model_key}/image"

View File

@ -22,7 +22,7 @@ def generate_ti_list(
for trigger in extract_ti_triggers_from_prompt(prompt):
name_or_key = trigger[1:-1]
try:
loaded_model = context.models.load(name_or_key)
loaded_model = context.models.load(key=name_or_key)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)
assert loaded_model.config.base == base

View File

@ -19,6 +19,7 @@ from invokeai.app.services.model_install import (
ModelInstallService,
ModelInstallServiceBase,
)
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
@ -38,7 +39,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
logger = InvokeAILogger.get_logger(config=app_config)
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
db = init_db(config=app_config, logger=logger, image_files=image_files)
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
return obj

View File

@ -161,7 +161,6 @@ class ModelConfigBase(BaseModel):
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
@ -311,7 +310,7 @@ class IPAdapterConfig(ModelConfigBase):
class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for CLIPVision."""
"""Model config for ClipVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]

View File

@ -12,8 +12,6 @@ import hashlib
import os
from pathlib import Path
from typing import Callable, Literal, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed
from blake3 import blake3
@ -107,14 +105,13 @@ class ModelHash:
"""
model_component_paths = self._get_file_paths(dir, self._file_filter)
# Use ThreadPoolExecutor to hash files in parallel
with ThreadPoolExecutor(min(((os.cpu_count() or 1) + 4), len(model_component_paths))) as executor:
future_to_component = {executor.submit(self._hash_file, component): component for component in sorted(model_component_paths)}
component_hashes = [future.result() for future in as_completed(future_to_component)]
component_hashes: list[str] = []
for component in sorted(model_component_paths):
component_hashes.append(self._hash_file(component))
# BLAKE3 to hash the hashes
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
# for the composite hash
composite_hasher = blake3()
component_hashes.sort()
for h in component_hashes:
composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest()
@ -132,12 +129,10 @@ class ModelHash:
"""
files: list[Path] = []
entries = [entry for entry in os.scandir(model_path.as_posix()) if not entry.name.startswith(".")]
dirs = [entry for entry in entries if entry.is_dir()]
file_paths = [entry.path for entry in entries if entry.is_file() and file_filter(entry.path)]
files.extend([Path(file) for file in file_paths])
for dir in dirs:
files.extend(ModelHash._get_file_paths(Path(dir.path), file_filter))
for root, _dirs, _files in os.walk(model_path):
for file in _files:
if file_filter(file):
files.append(Path(root, file))
return files
@staticmethod
@ -166,11 +161,13 @@ class ModelHash:
"""
def hashlib_hasher(file_path: Path) -> str:
"""Hashes a file using a hashlib algorithm."""
"""Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
hasher = hashlib.new(algorithm)
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(8 * 1024), b""):
hasher.update(chunk)
buffer = bytearray(128 * 1024)
mv = memoryview(buffer)
with open(file_path, "rb", buffering=0) as f:
while n := f.readinto(mv):
hasher.update(mv[:n])
return hasher.hexdigest()
return hashlib_hasher

View File

@ -24,7 +24,7 @@ from .. import ModelLoader, ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
class LoRALoader(ModelLoader):
class LoraLoader(ModelLoader):
"""Class to load LoRA models."""
# We cheat a little bit to get access to the model base

View File

@ -23,7 +23,7 @@ from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VAELoader(GenericDiffusersLoader):
class VaeLoader(GenericDiffusersLoader):
"""Class to load VAE models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:

View File

@ -84,9 +84,6 @@ class ProbeBase(object):
class ModelProbe(object):
hasher = ModelHash()
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
"diffusers": {},
"checkpoint": {},
@ -160,7 +157,7 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or cls.hasher.hash(model_path)
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()

View File

@ -858,9 +858,9 @@ def do_textual_inversion_training(
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:

View File

@ -42,10 +42,9 @@ def install_and_load_model(
# If the requested model is already installed, return its LoadedModel
with contextlib.suppress(UnknownModelException):
# TODO: Replace with wrapper call
configs = model_manager.store.search_by_attr(
loaded_model: LoadedModel = model_manager.load_model_by_attr(
model_name=model_name, base_model=base_model, model_type=model_type
)
loaded_model: LoadedModel = model_manager.load.load_model(configs[0])
return loaded_model
# Install the requested model.
@ -54,7 +53,7 @@ def install_and_load_model(
assert job.complete
try:
loaded_model = model_manager.load.load_model(job.config_out)
loaded_model = model_manager.load_model_by_config(job.config_out)
return loaded_model
except UnknownModelException as e:
raise Exception(

View File

@ -20,6 +20,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
@ -412,7 +413,7 @@ def get_config_store() -> ModelRecordServiceSQL:
assert output_path is not None
image_files = DiskImageFileStorage(output_path / "images")
db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
return ModelRecordServiceSQL(db)
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:

View File

@ -746,7 +746,6 @@
"delete": "Delete",
"deleteConfig": "Delete Config",
"deleteModel": "Delete Model",
"deleteModelImage": "Delete Model Image",
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"description": "Description",
@ -787,10 +786,6 @@
"modelDeleteFailed": "Failed to delete model",
"modelEntryDeleted": "Model Entry Deleted",
"modelExists": "Model Exists",
"modelImageDeleted": "Model Image Deleted",
"modelImageDeleteFailed": "Model Image Delete Failed",
"modelImageUpdated": "Model Image Updated",
"modelImageUpdateFailed": "Model Image Update Failed",
"modelLocation": "Model Location",
"modelLocationValidationMsg": "Provide the path to a local folder where your Diffusers Model is stored",
"modelManager": "Model Manager",
@ -823,7 +818,6 @@
"oliveModels": "Olives",
"onnxModels": "Onnx",
"path": "Path",
"pathToConfig": "Path To Config",
"pathToCustomConfig": "Path To Custom Config",
"pickModelType": "Pick Model Type",
"predictionType": "Prediction Type",
@ -858,7 +852,6 @@
"triggerPhrases": "Trigger Phrases",
"typePhraseHere": "Type phrase here",
"upcastAttention": "Upcast Attention",
"uploadImage": "Upload Image",
"updateModel": "Update Model",
"useCustomConfig": "Use Custom Config",
"useDefaultSettings": "Use Default Settings",
@ -951,7 +944,6 @@
"doesNotExist": "does not exist",
"downloadWorkflow": "Download Workflow JSON",
"edge": "Edge",
"edit": "Edit",
"editMode": "Edit in Workflow Editor",
"enum": "Enum",
"enumDescription": "Enums are values that may be one of a number of options.",
@ -1027,7 +1019,6 @@
"nodeTemplate": "Node Template",
"nodeType": "Node Type",
"noFieldsLinearview": "No fields added to Linear View",
"noFieldsViewMode": "This workflow has no selected fields to display. View the full workflow to configure values.",
"noFieldType": "No field type",
"noImageFoundState": "No initial image found in state",
"noMatchingNodes": "No matching nodes",
@ -1815,7 +1806,6 @@
"cursorPosition": "Cursor Position",
"darkenOutsideSelection": "Darken Outside Selection",
"discardAll": "Discard All",
"discardCurrent": "Discard Current",
"downloadAsImage": "Download As Image",
"emptyFolder": "Empty Folder",
"emptyTempImageFolder": "Empty Temp Image Folder",
@ -1825,7 +1815,6 @@
"eraseBoundingBox": "Erase Bounding Box",
"eraser": "Eraser",
"fillBoundingBox": "Fill Bounding Box",
"invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
"layer": "Layer",
"limitStrokesToBox": "Limit Strokes to Box",
"mask": "Mask",

View File

@ -115,8 +115,7 @@
"safetensors": "Safetensors",
"ai": "ia",
"file": "File",
"toResolve": "Da risolvere",
"add": "Aggiungi"
"toResolve": "Da risolvere"
},
"gallery": {
"generations": "Generazioni",
@ -154,12 +153,7 @@
"starImage": "Immagine preferita",
"dropToUpload": "$t(gallery.drop) per aggiornare",
"problemDeletingImagesDesc": "Impossibile eliminare una o più immagini",
"problemDeletingImages": "Problema durante l'eliminazione delle immagini",
"bulkDownloadRequested": "Preparazione del download",
"bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. L'operazione potrebbe richiedere alcuni istanti.",
"bulkDownloadRequestFailed": "Problema durante la preparazione del download",
"bulkDownloadStarting": "Avvio scaricamento",
"bulkDownloadFailed": "Scaricamento fallito"
"problemDeletingImages": "Problema durante l'eliminazione delle immagini"
},
"hotkeys": {
"keyboardShortcuts": "Tasti di scelta rapida",
@ -511,12 +505,12 @@
"modelSyncFailed": "Sincronizzazione modello non riuscita",
"settings": "Impostazioni",
"syncModels": "Sincronizza Modelli",
"syncModelsDesc": "Se i tuoi modelli non sono sincronizzati con il back-end, puoi aggiornarli utilizzando questa opzione. Questo è generalmente utile nei casi in cui aggiungi modelli alla cartella principale di InvokeAI dopo l'avvio dell'applicazione.",
"syncModelsDesc": "Se i tuoi modelli non sono sincronizzati con il back-end, puoi aggiornarli utilizzando questa opzione. Questo è generalmente utile nei casi in cui aggiorni manualmente il tuo file models.yaml o aggiungi modelli alla cartella principale di InvokeAI dopo l'avvio dell'applicazione.",
"loraModels": "LoRA",
"oliveModels": "Olive",
"onnxModels": "ONNX",
"noModels": "Nessun modello trovato",
"predictionType": "Tipo di previsione",
"predictionType": "Tipo di previsione (per modelli Stable Diffusion 2.x ed alcuni modelli Stable Diffusion 1.x)",
"quickAdd": "Aggiunta rapida",
"simpleModelDesc": "Fornire un percorso a un modello diffusori locale, un modello checkpoint/safetensor locale, un ID repository HuggingFace o un URL del modello checkpoint/diffusori.",
"advanced": "Avanzate",
@ -527,34 +521,7 @@
"vaePrecision": "Precisione VAE",
"noModelSelected": "Nessun modello selezionato",
"conversionNotSupported": "Conversione non supportata",
"configFile": "File di configurazione",
"modelName": "Nome del modello",
"modelSettings": "Impostazioni del modello",
"advancedImportInfo": "La scheda opzioni avanzate consente la configurazione manuale delle impostazioni del modello principale. Utilizza questa scheda solo se sei sicuro di conoscere il tipo di modello e la configurazione corretti per il modello selezionato.",
"addAll": "Aggiungi tutto",
"addModels": "Aggiungi modelli",
"cancel": "Annulla",
"edit": "Modifica",
"imageEncoderModelId": "ID modello codificatore di immagini",
"importQueue": "Coda di importazione",
"modelMetadata": "Metadati del modello",
"path": "Percorso",
"prune": "Elimina",
"pruneTooltip": "Elimina dalla coda le importazioni completate",
"removeFromQueue": "Rimuovi dalla coda",
"repoVariant": "Variante del repository",
"scan": "Scansiona",
"scanFolder": "Scansione cartella",
"scanResults": "Risultati della scansione",
"source": "Sorgente",
"upcastAttention": "Eleva l'attenzione",
"ztsnrTraining": "Addestramento ZTSNR",
"typePhraseHere": "Digita la frase qui",
"defaultSettingsSaved": "Impostazioni predefinite salvate",
"defaultSettings": "Impostazioni predefinite",
"metadata": "Metadati",
"useDefaultSettings": "Usa le impostazioni predefinite",
"triggerPhrases": "Frasi trigger"
"configFile": "File di configurazione"
},
"parameters": {
"images": "Immagini",
@ -636,8 +603,8 @@
"clipSkip": "CLIP Skip",
"aspectRatio": "Proporzioni",
"maskAdjustmentsHeader": "Regolazioni della maschera",
"maskBlur": "Sfocatura maschera",
"maskBlurMethod": "Metodo sfocatura maschera",
"maskBlur": "Sfocatura",
"maskBlurMethod": "Metodo di sfocatura",
"seamLowThreshold": "Basso",
"seamHighThreshold": "Alto",
"coherencePassHeader": "Passaggio di coerenza",
@ -694,8 +661,7 @@
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (potrebbe essere troppo grande)",
"boxBlur": "Box",
"gaussianBlur": "Gaussian",
"remixImage": "Remixa l'immagine",
"coherenceEdgeSize": "Dimensione bordo"
"remixImage": "Remixa l'immagine"
},
"settings": {
"models": "Modelli",
@ -778,8 +744,8 @@
"canceled": "Elaborazione annullata",
"problemCopyingImageLink": "Impossibile copiare il collegamento dell'immagine",
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
"parameterSet": "{{parameter}} impostato",
"parameterNotSet": "{{parameter}} non impostato",
"parameterSet": "Parametro impostato",
"parameterNotSet": "Parametro non impostato",
"nodesLoadedFailed": "Impossibile caricare i nodi",
"nodesSaved": "Nodi salvati",
"nodesLoaded": "Nodi caricati",
@ -832,10 +798,7 @@
"problemRetrievingWorkflow": "Problema nel recupero del flusso di lavoro",
"resetInitialImage": "Reimposta l'immagine iniziale",
"uploadInitialImage": "Carica l'immagine iniziale",
"problemDownloadingImage": "Impossibile scaricare l'immagine",
"prunedQueue": "Coda ripulita",
"modelImportCanceled": "Importazione del modello annullata",
"modelImportRemoved": "Importazione del modello rimossa"
"problemDownloadingImage": "Impossibile scaricare l'immagine"
},
"tooltip": {
"feature": {
@ -913,10 +876,7 @@
"antialiasing": "Anti aliasing",
"showResultsOn": "Mostra i risultati (attivato)",
"showResultsOff": "Mostra i risultati (disattivato)",
"saveMask": "Salva $t(unifiedCanvas.mask)",
"coherenceModeGaussianBlur": "Sfocatura Gaussiana",
"coherenceModeBoxBlur": "Sfocatura Box",
"coherenceModeStaged": "Maschera espansa"
"saveMask": "Salva $t(unifiedCanvas.mask)"
},
"accessibility": {
"modelSelect": "Seleziona modello",
@ -1385,8 +1345,7 @@
"allLoRAsAdded": "Tutti i LoRA aggiunti",
"defaultVAE": "VAE predefinito",
"incompatibleBaseModel": "Modello base incompatibile",
"loraAlreadyAdded": "LoRA già aggiunto",
"concepts": "Concetti"
"loraAlreadyAdded": "LoRA già aggiunto"
},
"invocationCache": {
"disable": "Disabilita",
@ -1739,25 +1698,6 @@
"paragraphs": [
"Valuta le generazioni in modo che siano più simili alle immagini con un punteggio estetico elevato, in base ai dati di addestramento."
]
},
"compositingCoherenceMinDenoise": {
"heading": "Livello minimo di riduzione del rumore",
"paragraphs": [
"Intensità minima di riduzione rumore per la modalità di Coerenza",
"L'intensità minima di riduzione del rumore per la regione di coerenza durante l'inpainting o l'outpainting"
]
},
"compositingMaskBlur": {
"paragraphs": [
"Il raggio di sfocatura della maschera."
],
"heading": "Sfocatura maschera"
},
"compositingCoherenceEdgeSize": {
"heading": "Dimensione del bordo",
"paragraphs": [
"La dimensione del bordo del passaggio di coerenza."
]
}
},
"sdxl": {
@ -1806,12 +1746,7 @@
"scheduler": "Campionatore",
"recallParameters": "Richiama i parametri",
"noRecallParameters": "Nessun parametro da richiamare trovato",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"allPrompts": "Tutti i prompt",
"imageDimensions": "Dimensioni dell'immagine",
"parameterSet": "Parametro {{parameter}} impostato",
"parsingFailed": "Analisi non riuscita",
"recallParameter": "Richiama {{label}}"
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
},
"hrf": {
"enableHrf": "Abilita Correzione Alta Risoluzione",
@ -1883,11 +1818,5 @@
"image": {
"title": "Immagine"
}
},
"prompt": {
"compatibleEmbeddings": "Incorporamenti compatibili",
"addPromptTrigger": "Aggiungi parola chiave nel prompt",
"noPromptTriggers": "Nessuna parola chiave disponibile",
"noMatchingTriggers": "Nessuna parola chiave corrispondente"
}
}

View File

@ -52,7 +52,7 @@
"accept": "Принять",
"postprocessing": "Постобработка",
"txt2img": "Текст в изображение (txt2img)",
"linear": "Линейный вид",
"linear": "Линейная обработка",
"dontAskMeAgain": "Больше не спрашивать",
"areYouSure": "Вы уверены?",
"random": "Случайное",
@ -117,8 +117,7 @@
"toResolve": "Чтоб решить",
"copy": "Копировать",
"localSystem": "Локальная система",
"aboutDesc": "Используя Invoke для работы? Проверьте это:",
"add": "Добавить"
"aboutDesc": "Используя Invoke для работы? Проверьте это:"
},
"gallery": {
"generations": "Генерации",
@ -156,12 +155,7 @@
"noImageSelected": "Изображение не выбрано",
"setCurrentImage": "Установить как текущее изображение",
"starImage": "Добавить в избранное",
"dropToUpload": "$t(gallery.drop) чтоб загрузить",
"bulkDownloadFailed": "Загрузка не удалась",
"bulkDownloadStarting": "Начало загрузки",
"bulkDownloadRequested": "Подготовка к скачиванию",
"bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.",
"bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания"
"dropToUpload": "$t(gallery.drop) чтоб загрузить"
},
"hotkeys": {
"keyboardShortcuts": "Горячие клавиши",
@ -510,7 +504,7 @@
"settings": "Настройки",
"selectModel": "Выберите модель",
"syncModels": "Синхронизация моделей",
"syncModelsDesc": "Если ваши модели не синхронизированы с серверной частью, вы можете обновить их с помощью этой опции. Обычно это удобно в тех случаях, когда вы добавляете модели в корневую папку InvokeAI или каталог автоимпорта после загрузки приложения.",
"syncModelsDesc": "Если ваши модели не синхронизированы с серверной частью, вы можете обновить их, используя эту опцию. Обычно это удобно в тех случаях, когда вы вручную обновляете свой файл \"models.yaml\" или добавляете модели в корневую папку InvokeAI после загрузки приложения.",
"modelUpdateFailed": "Не удалось обновить модель",
"modelConversionFailed": "Не удалось сконвертировать модель",
"modelsMergeFailed": "Не удалось выполнить слияние моделей",
@ -519,7 +513,7 @@
"oliveModels": "Модели Olives",
"conversionNotSupported": "Преобразование не поддерживается",
"noModels": "Нет моделей",
"predictionType": "Тип прогноза",
"predictionType": "Тип прогноза (для моделей Stable Diffusion 2.x и периодических моделей Stable Diffusion 1.x)",
"quickAdd": "Быстрое добавление",
"simpleModelDesc": "Укажите путь к локальной модели Diffusers , локальной модели checkpoint / safetensors, идентификатор репозитория HuggingFace или URL-адрес модели контрольной checkpoint / diffusers.",
"advanced": "Продвинутый",
@ -530,33 +524,7 @@
"customConfigFileLocation": "Расположение пользовательского файла конфигурации",
"vaePrecision": "Точность VAE",
"noModelSelected": "Модель не выбрана",
"configFile": "Файл конфигурации",
"addAll": "Добавить всё",
"addModels": "Добавить модели",
"cancel": "Отмена",
"defaultSettings": "Стандартные настройки",
"importQueue": "Импортировать очередь",
"metadata": "Метаданные",
"imageEncoderModelId": "ID модели-энкодера изображений",
"typePhraseHere": "Введите фразы здесь",
"advancedImportInfo": "Вкладка «Дополнительно» позволяет вручную настроить основные параметры модели. Используйте эту вкладку только в том случае, если вы уверены, что знаете правильный тип модели и конфигурацию выбранной модели.",
"defaultSettingsSaved": "Стандартные настройки сохранены",
"edit": "Редактировать",
"path": "Путь",
"prune": "Удалить",
"pruneTooltip": "Удалить готовые импорты из очереди",
"removeFromQueue": "Удалить из очереди",
"repoVariant": "Вариант репозитория",
"scan": "Сканировать",
"scanFolder": "Сканировать папку",
"scanResults": "Результаты сканирования",
"source": "Источник",
"triggerPhrases": "Триггерные фразы",
"useDefaultSettings": "Использовать стандартные настройки",
"modelMetadata": "Метаданные модели",
"modelName": "Название модели",
"modelSettings": "Настройки модели",
"upcastAttention": "Внимание"
"configFile": "Файл конфигурации"
},
"parameters": {
"images": "Изображения",
@ -623,7 +591,7 @@
"hSymmetryStep": "Шаг гор. симметрии",
"hidePreview": "Скрыть предпросмотр",
"imageToImage": "Изображение в изображение",
"denoisingStrength": "Сила зашумления",
"denoisingStrength": "Сила шумоподавления",
"copyImage": "Скопировать изображение",
"showPreview": "Показать предпросмотр",
"noiseSettings": "Шум",
@ -638,8 +606,8 @@
"clipSkip": "CLIP Пропуск",
"aspectRatio": "Соотношение",
"maskAdjustmentsHeader": "Настройка маски",
"maskBlur": "Размытие маски",
"maskBlurMethod": "Метод размытия маски",
"maskBlur": "Размытие",
"maskBlurMethod": "Метод размытия",
"seamLowThreshold": "Низкий",
"seamHighThreshold": "Высокий",
"coherencePassHeader": "Порог Coherence",
@ -698,9 +666,7 @@
"lockAspectRatio": "Заблокировать соотношение",
"boxBlur": "Размытие прямоугольника",
"gaussianBlur": "Размытие по Гауссу",
"remixImage": "Ремикс изображения",
"coherenceMinDenoise": "Мин. шумоподавление",
"coherenceEdgeSize": "Размер края"
"remixImage": "Ремикс изображения"
},
"settings": {
"models": "Модели",
@ -783,8 +749,8 @@
"canceled": "Обработка отменена",
"problemCopyingImageLink": "Не удалось скопировать ссылку на изображение",
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
"parameterNotSet": "Параметр {{parameter}} не задан",
"parameterSet": "Параметр {{parameter}} задан",
"parameterNotSet": "Параметр не задан",
"parameterSet": "Параметр задан",
"nodesLoaded": "Узлы загружены",
"problemCopyingImage": "Не удается скопировать изображение",
"nodesLoadedFailed": "Не удалось загрузить Узлы",
@ -837,10 +803,7 @@
"problemImportingMask": "Проблема с импортом маски",
"problemDownloadingImage": "Не удается скачать изображение",
"uploadInitialImage": "Загрузить начальное изображение",
"resetInitialImage": "Сбросить начальное изображение",
"prunedQueue": "Урезанная очередь",
"modelImportCanceled": "Импорт модели отменен",
"modelImportRemoved": "Импорт модели удален"
"resetInitialImage": "Сбросить начальное изображение"
},
"tooltip": {
"feature": {
@ -1182,11 +1145,7 @@
"reorderLinearView": "Изменить порядок линейного просмотра",
"viewMode": "Использовать в линейном представлении",
"editMode": "Открыть в редакторе узлов",
"resetToDefaultValue": "Сбросить к стандартному значкнию",
"latentsField": "Латенты",
"latentsCollectionDescription": "Латенты могут передаваться между узлами.",
"latentsPolymorphicDescription": "Латенты могут передаваться между узлами.",
"latentsFieldDescription": "Латенты могут передаваться между узлами."
"resetToDefaultValue": "Сбросить к стандартному значкнию"
},
"controlnet": {
"amult": "a_mult",
@ -1335,8 +1294,7 @@
},
"paramScheduler": {
"paragraphs": [
"Планировщик, используемый в процессе генерации.",
"Каждый планировщик определяет, как итеративно добавлять шум к изображению или как обновлять образец на основе выходных данных модели."
"Планировщик определяет, как итеративно добавлять шум к изображению или как обновлять образец на основе выходных данных модели."
],
"heading": "Планировщик"
},
@ -1389,7 +1347,7 @@
"compositingCoherenceMode": {
"heading": "Режим",
"paragraphs": [
"Метод, используемый для создания связного изображения с вновь созданной замаскированной областью."
"Режим прохождения когерентности."
]
},
"paramSeed": {
@ -1407,7 +1365,7 @@
},
"controlNetBeginEnd": {
"paragraphs": [
"Часть процесса шумоподавления, к которой будет применен адаптер контроля.",
"На каких этапах процесса шумоподавления будет применена ControlNet.",
"ControlNet, применяемые в начале процесса, направляют композицию, а ControlNet, применяемые в конце, направляют детали."
],
"heading": "Процент начала/конца шага"
@ -1423,8 +1381,8 @@
},
"clipSkip": {
"paragraphs": [
"Сколько слоев модели CLIP пропустить.",
"Некоторые модели лучше подходят для использования с CLIP Skip."
"Выберите, сколько слоев модели CLIP нужно пропустить.",
"Некоторые модели работают лучше с определенными настройками пропуска CLIP."
],
"heading": "CLIP пропуск"
},
@ -1521,25 +1479,6 @@
"paragraphs": [
"Более высокий вес LoRA приведет к большему влиянию на конечное изображение."
]
},
"compositingMaskBlur": {
"heading": "Размытие маски",
"paragraphs": [
"Радиус размытия маски."
]
},
"compositingCoherenceMinDenoise": {
"heading": "Минимальное шумоподавление",
"paragraphs": [
"Минимальный уровень шумоподавления для режима Coherence",
"Минимальный уровень шумоподавления для области когерентности при перерисовывании или дорисовке"
]
},
"compositingCoherenceEdgeSize": {
"heading": "Размер края",
"paragraphs": [
"Размер края прохода когерентности."
]
}
},
"metadata": {
@ -1570,12 +1509,7 @@
"steps": "Шаги",
"scheduler": "Планировщик",
"noRecallParameters": "Параметры для вызова не найдены",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"parameterSet": "Параметр {{parameter}} установлен",
"parsingFailed": "Не удалось выполнить синтаксический анализ",
"recallParameter": "Отозвать {{label}}",
"allPrompts": "Все запросы",
"imageDimensions": "Размеры изображения"
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
},
"queue": {
"status": "Статус",
@ -1654,11 +1588,10 @@
"denoisingStrength": "Шумоподавление",
"refinermodel": "Модель перерисовщик",
"posAestheticScore": "Положительная эстетическая оценка",
"concatPromptStyle": "Связывание запроса и стиля",
"concatPromptStyle": "Объединение запроса и стиля",
"loading": "Загрузка...",
"steps": "Шаги",
"posStylePrompt": "Запрос стиля",
"freePromptStyle": "Ручной запрос стиля"
"posStylePrompt": "Запрос стиля"
},
"invocationCache": {
"useCache": "Использовать кэш",
@ -1745,8 +1678,7 @@
"allLoRAsAdded": "Все LoRA добавлены",
"defaultVAE": "Стандартное VAE",
"incompatibleBaseModel": "Несовместимая базовая модель",
"loraAlreadyAdded": "LoRA уже добавлена",
"concepts": "Концепты"
"loraAlreadyAdded": "LoRA уже добавлена"
},
"app": {
"storeNotInitialized": "Магазин не инициализирован"
@ -1764,7 +1696,7 @@
},
"generation": {
"title": "Генерация",
"conceptsTab": "LoRA",
"conceptsTab": "Концепты",
"modelTab": "Модель"
},
"advanced": {

View File

@ -5,55 +5,18 @@ import openapiTS from 'openapi-typescript';
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
const OUTPUT_FILE = 'src/services/api/schema.ts';
async function generateTypes(schema) {
process.stdout.write(`Generating types ${OUTPUT_FILE}...`);
const types = await openapiTS(schema, {
async function main() {
process.stdout.write(`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...`);
const types = await openapiTS(OPENAPI_URL, {
exportType: true,
transform: (schemaObject) => {
if ('format' in schemaObject && schemaObject.format === 'binary') {
return schemaObject.nullable ? 'Blob | null' : 'Blob';
}
if (schemaObject.title === 'MetadataField') {
// This is `Record<string, never>` by default, but it actually accepts any a dict of any valid JSON value.
return 'Record<string, unknown>';
}
},
});
fs.writeFileSync(OUTPUT_FILE, types);
process.stdout.write(`\nOK!\r\n`);
}
async function main() {
const encoding = 'utf-8';
if (process.stdin.isTTY) {
// Handle generating types with an arg (e.g. URL or path to file)
if (process.argv.length > 3) {
console.error('Usage: typegen.js <openapi.json>');
process.exit(1);
}
if (process.argv[2]) {
const schema = new Buffer.from(process.argv[2], encoding);
generateTypes(schema);
} else {
generateTypes(OPENAPI_URL);
}
} else {
// Handle generating types from stdin
let schema = '';
process.stdin.setEncoding(encoding);
process.stdin.on('readable', function () {
const chunk = process.stdin.read();
if (chunk !== null) {
schema += chunk;
}
});
process.stdin.on('end', function () {
generateTypes(JSON.parse(schema));
});
}
}
main();

View File

@ -38,7 +38,7 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
type: 'image/png',
}),
image_category: 'control',
is_intermediate: true,
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: false,
postUploadAction: {

View File

@ -48,7 +48,7 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: true,
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: false,
postUploadAction: {

View File

@ -101,7 +101,7 @@ export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartLis
).unwrap();
}
const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
const graph = buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
log.debug({ graph: parseify(graph) }, `Canvas graph built`);

View File

@ -20,15 +20,15 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
if (model && model.base === 'sdxl') {
if (action.payload.tabName === 'txt2img') {
graph = await buildLinearSDXLTextToImageGraph(state);
graph = buildLinearSDXLTextToImageGraph(state);
} else {
graph = await buildLinearSDXLImageToImageGraph(state);
graph = buildLinearSDXLImageToImageGraph(state);
}
} else {
if (action.payload.tabName === 'txt2img') {
graph = await buildLinearTextToImageGraph(state);
graph = buildLinearTextToImageGraph(state);
} else {
graph = await buildLinearImageToImageGraph(state);
graph = buildLinearImageToImageGraph(state);
}
}

View File

@ -20,7 +20,7 @@ const sx: ChakraProps['sx'] = {
'.react-colorful__hue-pointer': colorPickerPointerStyles,
'.react-colorful__saturation-pointer': colorPickerPointerStyles,
'.react-colorful__alpha-pointer': colorPickerPointerStyles,
gap: 5,
gap: 2,
flexDir: 'column',
};
@ -39,8 +39,8 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
<Flex sx={sx}>
<RgbaColorPicker color={color} onChange={onChange} style={colorPickerStyles} {...rest} />
{withNumberInput && (
<Flex gap={5}>
<FormControl gap={0}>
<Flex>
<FormControl>
<FormLabel>{t('common.red')}</FormLabel>
<CompositeNumberInput
value={color.r}
@ -52,7 +52,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
defaultValue={90}
/>
</FormControl>
<FormControl gap={0}>
<FormControl>
<FormLabel>{t('common.green')}</FormLabel>
<CompositeNumberInput
value={color.g}
@ -64,7 +64,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
defaultValue={90}
/>
</FormControl>
<FormControl gap={0}>
<FormControl>
<FormLabel>{t('common.blue')}</FormLabel>
<CompositeNumberInput
value={color.b}
@ -76,7 +76,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
defaultValue={255}
/>
</FormControl>
<FormControl gap={0}>
<FormControl>
<FormLabel>{t('common.alpha')}</FormLabel>
<CompositeNumberInput
value={color.a}

View File

@ -29,7 +29,7 @@ import { Layer, Stage } from 'react-konva';
import IAICanvasBoundingBoxOverlay from './IAICanvasBoundingBoxOverlay';
import IAICanvasGrid from './IAICanvasGrid';
import IAICanvasIntermediateImage from './IAICanvasIntermediateImage';
import IAICanvasMaskCompositor from './IAICanvasMaskCompositor';
import IAICanvasMaskCompositer from './IAICanvasMaskCompositer';
import IAICanvasMaskLines from './IAICanvasMaskLines';
import IAICanvasObjectRenderer from './IAICanvasObjectRenderer';
import IAICanvasStagingArea from './IAICanvasStagingArea';
@ -176,7 +176,7 @@ const IAICanvas = () => {
</Layer>
<Layer id="mask" visible={isMaskEnabled && !isStaging} listening={false}>
<IAICanvasMaskLines visible={true} listening={false} />
<IAICanvasMaskCompositor listening={false} />
<IAICanvasMaskCompositer listening={false} />
</Layer>
<Layer listening={false}>
<IAICanvasBoundingBoxOverlay />

View File

@ -16,9 +16,9 @@ const canvasMaskCompositerSelector = createMemoizedSelector(selectCanvasSlice, (
};
});
type IAICanvasMaskCompositorProps = RectConfig;
type IAICanvasMaskCompositerProps = RectConfig;
const IAICanvasMaskCompositor = (props: IAICanvasMaskCompositorProps) => {
const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
const { ...rest } = props;
const { stageCoordinates, stageDimensions } = useAppSelector(canvasMaskCompositerSelector);
@ -89,4 +89,4 @@ const IAICanvasMaskCompositor = (props: IAICanvasMaskCompositorProps) => {
);
};
export default memo(IAICanvasMaskCompositor);
export default memo(IAICanvasMaskCompositer);

View File

@ -5,7 +5,6 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { stagingAreaImageSaved } from 'features/canvas/store/actions';
import {
commitStagingAreaImage,
discardStagedImage,
discardStagedImages,
nextStagingAreaImage,
prevStagingAreaImage,
@ -23,7 +22,6 @@ import {
PiEyeBold,
PiEyeSlashBold,
PiFloppyDiskBold,
PiTrashSimpleBold,
PiXBold,
} from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
@ -46,40 +44,6 @@ const selector = createMemoizedSelector(selectCanvasSlice, (canvas) => {
};
});
const ClearStagingIntermediatesIconButton = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleDiscardStagingArea = useCallback(() => {
dispatch(discardStagedImages());
}, [dispatch]);
const handleDiscardStagingImage = useCallback(() => {
dispatch(discardStagedImage());
}, [dispatch]);
return (
<>
<IconButton
tooltip={`${t('unifiedCanvas.discardCurrent')}`}
aria-label={t('unifiedCanvas.discardCurrent')}
icon={<PiXBold />}
onClick={handleDiscardStagingImage}
colorScheme="invokeBlue"
fontSize={16}
/>
<IconButton
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
aria-label={t('unifiedCanvas.discardAll')}
icon={<PiTrashSimpleBold />}
onClick={handleDiscardStagingArea}
colorScheme="error"
fontSize={16}
/>
</>
);
};
const IAICanvasStagingAreaToolbar = () => {
const dispatch = useAppDispatch();
const { currentStagingAreaImage, shouldShowStagingImage, currentIndex, total } = useAppSelector(selector);
@ -221,7 +185,14 @@ const IAICanvasStagingAreaToolbar = () => {
onClick={handleSaveToGallery}
colorScheme="invokeBlue"
/>
<ClearStagingIntermediatesIconButton />
<IconButton
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
aria-label={t('unifiedCanvas.discardAll')}
icon={<PiXBold />}
onClick={handleDiscardStagingArea}
colorScheme="error"
fontSize={20}
/>
</ButtonGroup>
</Flex>
);

View File

@ -18,7 +18,6 @@ import {
setShouldAutoSave,
setShouldCropToBoundingBoxOnSave,
setShouldDarkenOutsideBoundingBox,
setShouldInvertBrushSizeScrollDirection,
setShouldRestrictStrokesToBox,
setShouldShowCanvasDebugInfo,
setShouldShowGrid,
@ -41,7 +40,6 @@ const IAICanvasSettingsButtonPopover = () => {
const shouldAutoSave = useAppSelector((s) => s.canvas.shouldAutoSave);
const shouldCropToBoundingBoxOnSave = useAppSelector((s) => s.canvas.shouldCropToBoundingBoxOnSave);
const shouldDarkenOutsideBoundingBox = useAppSelector((s) => s.canvas.shouldDarkenOutsideBoundingBox);
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
const shouldShowCanvasDebugInfo = useAppSelector((s) => s.canvas.shouldShowCanvasDebugInfo);
const shouldShowGrid = useAppSelector((s) => s.canvas.shouldShowGrid);
const shouldShowIntermediates = useAppSelector((s) => s.canvas.shouldShowIntermediates);
@ -78,10 +76,6 @@ const IAICanvasSettingsButtonPopover = () => {
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked)),
[dispatch]
);
const handleChangeShouldInvertBrushSizeScrollDirection = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldInvertBrushSizeScrollDirection(e.target.checked)),
[dispatch]
);
const handleChangeShouldAutoSave = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldAutoSave(e.target.checked)),
[dispatch]
@ -150,13 +144,6 @@ const IAICanvasSettingsButtonPopover = () => {
<FormLabel>{t('unifiedCanvas.limitStrokesToBox')}</FormLabel>
<Checkbox isChecked={shouldRestrictStrokesToBox} onChange={handleChangeShouldRestrictStrokesToBox} />
</FormControl>
<FormControl>
<FormLabel>{t('unifiedCanvas.invertBrushSizeScrollDirection')}</FormLabel>
<Checkbox
isChecked={shouldInvertBrushSizeScrollDirection}
onChange={handleChangeShouldInvertBrushSizeScrollDirection}
/>
</FormControl>
<FormControl>
<FormLabel>{t('unifiedCanvas.showCanvasDebugInfo')}</FormLabel>
<Checkbox isChecked={shouldShowCanvasDebugInfo} onChange={handleChangeShouldShowCanvasDebugInfo} />

View File

@ -15,7 +15,6 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
const stageScale = useAppSelector((s) => s.canvas.stageScale);
const isMoveStageKeyHeld = useStore($isMoveStageKeyHeld);
const brushSize = useAppSelector((s) => s.canvas.brushSize);
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
return useCallback(
(e: KonvaEventObject<WheelEvent>) => {
@ -29,16 +28,10 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
// checking for ctrl key is pressed or not,
// so that brush size can be controlled using ctrl + scroll up/down
// Invert the delta if the property is set to true
let delta = e.evt.deltaY;
if (shouldInvertBrushSizeScrollDirection) {
delta = -delta;
}
if ($ctrl.get() || $meta.get()) {
// This equation was derived by fitting a curve to the desired brush sizes and deltas
// see https://github.com/invoke-ai/InvokeAI/pull/5542#issuecomment-1915847565
const targetDelta = Math.sign(delta) * 0.7363 * Math.pow(1.0394, brushSize);
const targetDelta = Math.sign(e.evt.deltaY) * 0.7363 * Math.pow(1.0394, brushSize);
// This needs to be clamped to prevent the delta from getting too large
const finalDelta = clamp(targetDelta, -20, 20);
// The new brush size is also clamped to prevent it from getting too large or small
@ -74,7 +67,7 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
dispatch(setStageCoordinates(newCoordinates));
}
},
[stageRef, isMoveStageKeyHeld, brushSize, dispatch, stageScale, shouldInvertBrushSizeScrollDirection]
[stageRef, isMoveStageKeyHeld, stageScale, dispatch, brushSize]
);
};

View File

@ -65,7 +65,6 @@ const initialCanvasState: CanvasState = {
shouldAutoSave: false,
shouldCropToBoundingBoxOnSave: false,
shouldDarkenOutsideBoundingBox: false,
shouldInvertBrushSizeScrollDirection: false,
shouldLockBoundingBox: false,
shouldPreserveMaskedArea: false,
shouldRestrictStrokesToBox: true,
@ -221,9 +220,6 @@ export const canvasSlice = createSlice({
setShouldDarkenOutsideBoundingBox: (state, action: PayloadAction<boolean>) => {
state.shouldDarkenOutsideBoundingBox = action.payload;
},
setShouldInvertBrushSizeScrollDirection: (state, action: PayloadAction<boolean>) => {
state.shouldInvertBrushSizeScrollDirection = action.payload;
},
clearCanvasHistory: (state) => {
state.pastLayerStates = [];
state.futureLayerStates = [];
@ -292,31 +288,6 @@ export const canvasSlice = createSlice({
state.shouldShowStagingImage = true;
state.batchIds = [];
},
discardStagedImage: (state) => {
const { images, selectedImageIndex } = state.layerState.stagingArea;
state.pastLayerStates.push(cloneDeep(state.layerState));
if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
if (!images.length) {
return;
}
images.splice(selectedImageIndex, 1);
if (selectedImageIndex >= images.length) {
state.layerState.stagingArea.selectedImageIndex = images.length - 1;
}
if (!images.length) {
state.shouldShowStagingImage = false;
state.shouldShowStagingOutline = false;
}
state.futureLayerStates = [];
},
addFillRect: (state) => {
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } = state;
@ -684,7 +655,6 @@ export const {
commitColorPickerColor,
commitStagingAreaImage,
discardStagedImages,
discardStagedImage,
nextStagingAreaImage,
prevStagingAreaImage,
redo,
@ -704,7 +674,6 @@ export const {
setShouldAutoSave,
setShouldCropToBoundingBoxOnSave,
setShouldDarkenOutsideBoundingBox,
setShouldInvertBrushSizeScrollDirection,
setShouldPreserveMaskedArea,
setShouldShowBoundingBox,
setShouldShowCanvasDebugInfo,

View File

@ -120,7 +120,6 @@ export interface CanvasState {
shouldAutoSave: boolean;
shouldCropToBoundingBoxOnSave: boolean;
shouldDarkenOutsideBoundingBox: boolean;
shouldInvertBrushSizeScrollDirection: boolean;
shouldLockBoundingBox: boolean;
shouldPreserveMaskedArea: boolean;
shouldRestrictStrokesToBox: boolean;

View File

@ -6,7 +6,7 @@ const AutoAddIcon = () => {
const { t } = useTranslation();
return (
<Flex position="absolute" insetInlineEnd={0} top={0} p={1}>
<Badge variant="solid" bg="invokeBlue.400">
<Badge variant="solid" bg="invokeBlue.500">
{t('common.auto')}
</Badge>
</Flex>

View File

@ -173,8 +173,8 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
w="full"
maxW="full"
borderBottomRadius="base"
bg={isSelected ? 'invokeBlue.400' : 'base.600'}
color={isSelected ? 'base.800' : 'base.100'}
bg={isSelected ? 'invokeBlue.500' : 'base.600'}
color={isSelected ? 'base.50' : 'base.100'}
lineHeight="short"
fontSize="xs"
>
@ -193,7 +193,6 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
overflow="hidden"
textOverflow="ellipsis"
noOfLines={1}
color="inherit"
/>
<EditableInput sx={editableInputStyles} />
</Editable>

View File

@ -109,8 +109,8 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
w="full"
maxW="full"
borderBottomRadius="base"
bg={isSelected ? 'invokeBlue.400' : 'base.600'}
color={isSelected ? 'base.800' : 'base.100'}
bg={isSelected ? 'invokeBlue.500' : 'base.600'}
color={isSelected ? 'base.50' : 'base.100'}
lineHeight="short"
fontSize="xs"
fontWeight={isSelected ? 'bold' : 'normal'}

View File

@ -15,7 +15,7 @@ export const MetadataItemView = memo(
return (
<Flex gap={2}>
{onRecall && <RecallButton label={label} onClick={onRecall} isDisabled={isDisabled} />}
<Flex direction={direction} fontSize="sm">
<Flex direction={direction}>
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
{label}:
</Text>

View File

@ -1,27 +0,0 @@
import { Box, Image } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
type Props = {
image_url?: string;
};
const ModelImage = ({ image_url }: Props) => {
if (!image_url) {
return <Box height="50px" minWidth="50px" />;
}
return (
<Image
src={image_url}
objectFit="cover"
objectPosition="50% 50%"
height="50px"
width="50px"
minHeight="50px"
minWidth="50px"
borderRadius="base"
/>
);
};
export default typedMemo(ModelImage);

View File

@ -22,8 +22,6 @@ import { PiTrashSimpleBold } from 'react-icons/pi';
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import ModelImage from './ModelImage';
type ModelListItemProps = {
model: AnyModelConfig;
};
@ -75,7 +73,6 @@ const ModelListItem = (props: ModelListItemProps) => {
return (
<Flex gap={2} alignItems="center" w="full">
<ModelImage image_url={model.cover_image || ''} />
<Flex
as={Button}
isChecked={isSelected}

View File

@ -1,134 +0,0 @@
import { Box, Button, IconButton, Image } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { typedMemo } from 'common/util/typedMemo';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback, useState } from 'react';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
import { useDeleteModelImageMutation, useUpdateModelImageMutation } from 'services/api/endpoints/models';
type Props = {
model_key: string | null;
model_image?: string | null;
};
const ModelImageUpload = ({ model_key, model_image }: Props) => {
const dispatch = useAppDispatch();
const [image, setImage] = useState<string | null>(model_image || null);
const { t } = useTranslation();
const [updateModelImage] = useUpdateModelImageMutation();
const [deleteModelImage] = useDeleteModelImageMutation();
const onDropAccepted = useCallback(
(files: File[]) => {
const file = files[0];
if (!file || !model_key) {
return;
}
updateModelImage({ key: model_key, image: file })
.unwrap()
.then(() => {
setImage(URL.createObjectURL(file));
dispatch(
addToast(
makeToast({
title: t('modelManager.modelImageUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelImageUpdateFailed'),
status: 'error',
})
)
);
});
},
[dispatch, model_key, t, updateModelImage]
);
const handleResetImage = useCallback(() => {
if (!model_key) {
return;
}
setImage(null);
deleteModelImage(model_key)
.unwrap()
.then(() => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelImageDeleted'),
status: 'success',
})
)
);
})
.catch((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelImageDeleteFailed'),
status: 'error',
})
)
);
});
}, [dispatch, model_key, t, deleteModelImage]);
const { getInputProps, getRootProps } = useDropzone({
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
onDropAccepted,
noDrag: true,
multiple: false,
});
if (image) {
return (
<Box position="relative">
<Image
src={image}
objectFit="cover"
objectPosition="50% 50%"
height="100px"
width="100px"
minWidth="100px"
borderRadius="base"
/>
<IconButton
position="absolute"
top="1"
right="1"
onClick={handleResetImage}
aria-label={t('modelManager.deleteModelImage')}
tooltip={t('modelManager.deleteModelImage')}
icon={<PiArrowCounterClockwiseBold size={16} />}
size="sm"
variant="link"
_hover={{ color: 'base.100' }}
/>
</Box>
);
}
return (
<>
<Button leftIcon={<PiUploadSimpleBold />} {...getRootProps()} pointerEvents="auto">
{t('modelManager.uploadImage')}
</Button>
<input {...getInputProps()} />
</>
);
};
export default typedMemo(ModelImageUpload);

View File

@ -4,7 +4,6 @@ import { useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import ModelImageUpload from './Fields/ModelImageUpload';
import { ModelMetadata } from './Metadata/ModelMetadata';
import { ModelAttrView } from './ModelAttrView';
import { ModelEdit } from './ModelEdit';
@ -26,22 +25,19 @@ export const Model = () => {
return (
<>
<Flex alignItems="center" justifyContent="space-between" gap="4" paddingRight="5">
<Flex flexDir="column" gap={1} p={2}>
<Heading as="h2" fontSize="lg">
{data.name}
</Heading>
<Flex flexDir="column" gap={1} p={2}>
<Heading as="h2" fontSize="lg">
{data.name}
</Heading>
{data.source && (
<Text variant="subtext">
{t('modelManager.source')}: {data?.source}
</Text>
)}
<Box mt="4">
<ModelAttrView label="Description" value={data.description} />
</Box>
</Flex>
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
{data.source && (
<Text variant="subtext">
{t('modelManager.source')}: {data?.source}
</Text>
)}
<Box mt="4">
<ModelAttrView label="Description" value={data.description} />
</Box>
</Flex>
<Tabs mt="4" h="100%">

View File

@ -129,7 +129,7 @@ export const ModelEdit = () => {
</Flex>
<Flex flexDir="column" gap={3} mt="4">
<Flex gap="4" alignItems="center">
<Flex>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea fontSize="md" resize="none" {...register('description')} />
@ -145,32 +145,20 @@ export const ModelEdit = () => {
</FormControl>
</Flex>
{data.type === 'main' && data.format === 'checkpoint' && (
<>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input
{...register('config_path', {
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
})}
/>
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={control} />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<Checkbox {...register('upcast_attention')} />
</FormControl>
</Flex>
</>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<Checkbox {...register('upcast_attention')} />
</FormControl>
</Flex>
)}
</Flex>
</form>

View File

@ -90,26 +90,21 @@ export const ModelView = () => {
<ModelAttrView label={t('common.format')} value={modelData.format} />
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Flex>
{modelData.type === 'main' && modelData.format === 'diffusers' && modelData.repo_variant && (
{modelData.type === 'main' && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
{modelData.format === 'diffusers' && modelData.repo_variant && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<>
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</>
)}
</Flex>
)}
{modelData.type === 'main' && modelData.format === 'checkpoint' && (
<>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</Flex>
</>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
@ -117,11 +112,9 @@ export const ModelView = () => {
)}
</Flex>
</Box>
{modelData.type === 'main' && (
<Box layerStyle="second" borderRadius="base" p={3}>
<DefaultSettings />
</Box>
)}
<Box layerStyle="second" borderRadius="base" p={3}>
<DefaultSettings />
</Box>
</Flex>
);
};

View File

@ -1,55 +0,0 @@
import { Button, Flex, Image, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { workflowModeChanged } from 'features/nodes/store/workflowSlice';
import InvokeLogoSVG from 'public/assets/images/invoke-symbol-wht-lrg.svg';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const EmptyState = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onClick = useCallback(() => {
dispatch(workflowModeChanged('edit'));
}, [dispatch]);
return (
<Flex
sx={{
w: 'full',
h: 'full',
userSelect: 'none',
}}
>
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
flexDir: 'column',
gap: 5,
maxW: '230px',
margin: '0 auto',
}}
>
<Image
src={InvokeLogoSVG}
alt="invoke-ai-logo"
opacity={0.2}
mixBlendMode="overlay"
w={16}
h={16}
minW={16}
minH={16}
userSelect="none"
/>
<Text textAlign="center" fontSize="md">
{t('nodes.noFieldsViewMode')}
</Text>
<Button colorScheme="invokeBlue" onClick={onClick}>
{t('nodes.edit')}
</Button>
</Flex>
</Flex>
);
};

View File

@ -7,7 +7,6 @@ import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
import { t } from 'i18next';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
import { EmptyState } from './EmptyState';
import WorkflowField from './WorkflowField';
const selector = createMemoizedSelector(selectWorkflowSlice, (workflow) => {
@ -31,7 +30,7 @@ export const WorkflowViewMode = () => {
<WorkflowField key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
))
) : (
<EmptyState />
<IAINoContentFallback label={t('nodes.noFieldsLinearview')} icon={null} />
)}
</Flex>
</ScrollableContent>

View File

@ -12,17 +12,17 @@ const WorkflowPanel = () => {
<Flex layerStyle="first" flexDir="column" w="full" h="full" borderRadius="base" p={2} gap={2}>
<Tabs variant="line" display="flex" w="full" h="full" flexDir="column">
<TabList>
<Tab>{t('common.linear')}</Tab>
<Tab>{t('common.details')}</Tab>
<Tab>{t('common.linear')}</Tab>
<Tab>JSON</Tab>
</TabList>
<TabPanels>
<TabPanel>
<WorkflowLinearTab />
<WorkflowGeneralTab />
</TabPanel>
<TabPanel>
<WorkflowGeneralTab />
<WorkflowLinearTab />
</TabPanel>
<TabPanel>
<WorkflowJSONTab />

View File

@ -55,22 +55,8 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zSubModelType = z.enum([
'unet',
'text_encoder',
'text_encoder_2',
'tokenizer',
'tokenizer_2',
'vae',
'vae_decoder',
'vae_encoder',
'scheduler',
'safety_checker',
]);
const zModelIdentifier = z.object({
key: z.string().min(1),
submodel_type: zSubModelType.nullish(),
});
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
zModelIdentifier.safeParse(field).success;

View File

@ -1,22 +1,18 @@
import type { RootState } from 'app/store/store';
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { omit } from 'lodash-es';
import type {
CollectInvocation,
ControlField,
ControlNetInvocation,
CoreMetadataInvocation,
NonNullableGraph,
} from 'services/api/types';
import { isControlNetModelConfig } from 'services/api/types';
import { CONTROL_NET_COLLECT } from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';
export const addControlNetToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
);
@ -43,7 +39,7 @@ export const addControlNetToLinearGraph = async (
},
});
validControlNets.forEach(async (controlNet) => {
validControlNets.forEach((controlNet) => {
if (!controlNet.model) {
return;
}
@ -89,17 +85,7 @@ export const addControlNetToLinearGraph = async (
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isControlNetModelConfig);
controlNetMetadata.push({
control_model: getModelMetadataField(modelConfig),
control_weight: weight,
control_mode: controlMode,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
image: controlNetNode.image,
});
controlNetMetadata.push(omit(controlNetNode, ['id', 'type', 'is_intermediate']) as ControlField);
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },

View File

@ -1,22 +1,18 @@
import type { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { omit } from 'lodash-es';
import type {
CollectInvocation,
CoreMetadataInvocation,
IPAdapterInvocation,
IPAdapterMetadataField,
NonNullableGraph,
} from 'services/api/types';
import { isIPAdapterModelConfig } from 'services/api/types';
import { IP_ADAPTER_COLLECT } from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';
export const addIPAdapterToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
);
@ -39,7 +35,7 @@ export const addIPAdapterToLinearGraph = async (
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
validIPAdapters.forEach(async (ipAdapter) => {
validIPAdapters.forEach((ipAdapter) => {
if (!ipAdapter.model) {
return;
}
@ -62,17 +58,9 @@ export const addIPAdapterToLinearGraph = async (
return;
}
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isIPAdapterModelConfig);
ipAdapterMetdata.push({
weight: weight,
ip_adapter_model: getModelMetadataField(modelConfig),
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image: ipAdapterNode.image,
});
ipAdapterMetdata.push(omit(ipAdapterNode, ['id', 'type', 'is_intermediate']) as IPAdapterMetadataField);
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },

View File

@ -1,22 +1,16 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { filter, size } from 'lodash-es';
import {
type CoreMetadataInvocation,
isLoRAModelConfig,
type LoRALoaderInvocation,
type NonNullableGraph,
} from 'services/api/types';
import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types';
import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';
export const addLoRAsToGraph = async (
export const addLoRAsToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId: string = MAIN_MODEL_LOADER
): Promise<void> => {
): void => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
@ -45,12 +39,12 @@ export const addLoRAsToGraph = async (
let currentLoraIndex = 0;
const loraMetadata: CoreMetadataInvocation['loras'] = [];
enabledLoRAs.forEach(async (lora) => {
enabledLoRAs.forEach((lora) => {
const { weight } = lora;
const { key } = lora.model;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const loraLoaderNode: LoRALoaderInvocation = {
const loraLoaderNode: LoraLoaderInvocation = {
type: 'lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
@ -58,10 +52,8 @@ export const addLoRAsToGraph = async (
weight,
};
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
loraMetadata.push({
model: getModelMetadataField(modelConfig),
model: { key },
weight,
});

View File

@ -1,12 +1,6 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { filter, size } from 'lodash-es';
import {
type CoreMetadataInvocation,
isLoRAModelConfig,
type NonNullableGraph,
type SDXLLoRALoaderInvocation,
} from 'services/api/types';
import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoraLoaderInvocation } from 'services/api/types';
import {
LORA_LOADER,
@ -16,14 +10,14 @@ import {
SDXL_REFINER_INPAINT_CREATE_MASK,
SEAMLESS,
} from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';
export const addSDXLLoRAsToGraph = async (
export const addSDXLLoRAsToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId: string = SDXL_MODEL_LOADER
): Promise<void> => {
): void => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
@ -61,12 +55,12 @@ export const addSDXLLoRAsToGraph = async (
let lastLoraNodeId = '';
let currentLoraIndex = 0;
enabledLoRAs.forEach(async (lora) => {
enabledLoRAs.forEach((lora) => {
const { weight } = lora;
const { key } = lora.model;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const loraLoaderNode: SDXLLoRALoaderInvocation = {
const loraLoaderNode: SDXLLoraLoaderInvocation = {
type: 'sdxl_lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
@ -74,9 +68,7 @@ export const addSDXLLoRAsToGraph = async (
weight,
};
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
loraMetadata.push({ model: getModelMetadataField(modelConfig), weight });
loraMetadata.push({ model: { key }, weight });
// add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode;

View File

@ -1,11 +1,9 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type CreateDenoiseMaskInvocation,
type ImageDTO,
isRefinerMainModelModelConfig,
type NonNullableGraph,
type SeamlessModeInvocation,
import type {
CreateDenoiseMaskInvocation,
ImageDTO,
NonNullableGraph,
SeamlessModeInvocation,
} from 'services/api/types';
import {
@ -27,16 +25,16 @@ import {
SDXL_REFINER_SEAMLESS,
} from './constants';
import { getSDXLStylePrompts } from './graphBuilderUtils';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';
export const addSDXLRefinerToGraph = async (
export const addSDXLRefinerToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId?: string,
canvasInitImage?: ImageDTO,
canvasMaskImage?: ImageDTO
): Promise<void> => {
): void => {
const {
refinerModel,
refinerPositiveAestheticScore,
@ -57,10 +55,9 @@ export const addSDXLRefinerToGraph = async (
const fp32 = vaePrecision === 'fp32';
const isUsingScaledDimensions = ['auto', 'manual'].includes(boundingBoxScaleMethod);
const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig);
upsertMetadata(graph, {
refiner_model: getModelMetadataField(modelConfig),
refiner_model: refinerModel,
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
refiner_cfg_scale: refinerCFGScale,

View File

@ -1,22 +1,18 @@
import type { RootState } from 'app/store/store';
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type CollectInvocation,
type CoreMetadataInvocation,
isT2IAdapterModelConfig,
type NonNullableGraph,
type T2IAdapterInvocation,
import { omit } from 'lodash-es';
import type {
CollectInvocation,
CoreMetadataInvocation,
NonNullableGraph,
T2IAdapterField,
T2IAdapterInvocation,
} from 'services/api/types';
import { T2I_ADAPTER_COLLECT } from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';
export const addT2IAdaptersToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
);
@ -37,9 +33,9 @@ export const addT2IAdaptersToLinearGraph = async (
},
});
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
const t2iAdapterMetdata: CoreMetadataInvocation['t2iAdapters'] = [];
validT2IAdapters.forEach(async (t2iAdapter) => {
validT2IAdapters.forEach((t2iAdapter) => {
if (!t2iAdapter.model) {
return;
}
@ -81,18 +77,9 @@ export const addT2IAdaptersToLinearGraph = async (
return;
}
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode as T2IAdapterInvocation;
const modelConfig = await fetchModelConfigWithTypeGuard(t2iAdapter.model.key, isT2IAdapterModelConfig);
t2iAdapterMetadata.push({
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
t2i_adapter_model: getModelMetadataField(modelConfig),
weight: weight,
image: t2iAdapterNode.image,
});
t2iAdapterMetdata.push(omit(t2iAdapterNode, ['id', 'type', 'is_intermediate']) as T2IAdapterField);
graph.edges.push({
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
@ -103,6 +90,6 @@ export const addT2IAdaptersToLinearGraph = async (
});
});
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetadata });
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetdata });
}
};

View File

@ -1,7 +1,5 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { ModelMetadataField, NonNullableGraph } from 'services/api/types';
import { isVAEModelConfig } from 'services/api/types';
import type { NonNullableGraph } from 'services/api/types';
import {
CANVAS_IMAGE_TO_IMAGE_GRAPH,
@ -25,13 +23,13 @@ import {
TEXT_TO_IMAGE_GRAPH,
VAE_LOADER,
} from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';
export const addVAEToGraph = async (
export const addVAEToGraph = (
state: RootState,
graph: NonNullableGraph,
modelLoaderNodeId: string = MAIN_MODEL_LOADER
): Promise<void> => {
): void => {
const { vae, seamlessXAxis, seamlessYAxis } = state.generation;
const { boundingBoxScaleMethod } = state.canvas;
const { refinerModel } = state.sdxl;
@ -151,8 +149,6 @@ export const addVAEToGraph = async (
}
if (vae) {
const modelConfig = await fetchModelConfigWithTypeGuard(vae.key, isVAEModelConfig);
const vaeMetadata: ModelMetadataField = getModelMetadataField(modelConfig);
upsertMetadata(graph, { vae: vaeMetadata });
upsertMetadata(graph, { vae });
}
};

View File

@ -10,46 +10,46 @@ import { buildCanvasSDXLOutpaintGraph } from './buildCanvasSDXLOutpaintGraph';
import { buildCanvasSDXLTextToImageGraph } from './buildCanvasSDXLTextToImageGraph';
import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph';
export const buildCanvasGraph = async (
export const buildCanvasGraph = (
state: RootState,
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
canvasInitImage: ImageDTO | undefined,
canvasMaskImage: ImageDTO | undefined
): Promise<NonNullableGraph> => {
) => {
let graph: NonNullableGraph;
if (generationMode === 'txt2img') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = await buildCanvasSDXLTextToImageGraph(state);
graph = buildCanvasSDXLTextToImageGraph(state);
} else {
graph = await buildCanvasTextToImageGraph(state);
graph = buildCanvasTextToImageGraph(state);
}
} else if (generationMode === 'img2img') {
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = await buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
} else {
graph = await buildCanvasImageToImageGraph(state, canvasInitImage);
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
}
} else if (generationMode === 'inpaint') {
if (!canvasInitImage || !canvasMaskImage) {
throw new Error('Missing canvas init and mask images');
}
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = await buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
} else {
graph = await buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
}
} else {
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = await buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
} else {
graph = await buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
}
}

View File

@ -1,13 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import {
type ImageDTO,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -31,15 +25,12 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';
/**
* Builds the Canvas tab's Image to Image graph.
*/
export const buildCanvasImageToImageGraph = async (
state: RootState,
initialImage: ImageDTO
): Promise<NonNullableGraph> => {
export const buildCanvasImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -315,8 +306,6 @@ export const buildCanvasImageToImageGraph = async (
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -327,7 +316,7 @@ export const buildCanvasImageToImageGraph = async (
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -346,17 +335,17 @@ export const buildCanvasImageToImageGraph = async (
}
// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS);
addLoRAsToGraph(state, graph, DENOISE_LATENTS);
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);
// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -40,11 +40,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
/**
* Builds the Canvas tab's Inpaint graph.
*/
export const buildCanvasInpaintGraph = async (
export const buildCanvasInpaintGraph = (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage: ImageDTO
): Promise<NonNullableGraph> => {
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -414,17 +414,17 @@ export const buildCanvasInpaintGraph = async (
}
// Add VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@ -44,11 +44,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
/**
* Builds the Canvas tab's Outpaint graph.
*/
export const buildCanvasOutpaintGraph = async (
export const buildCanvasOutpaintGraph = (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage?: ImageDTO
): Promise<NonNullableGraph> => {
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -545,18 +545,18 @@ export const buildCanvasOutpaintGraph = async (
}
// Add VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,12 +1,6 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type ImageDTO,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -32,15 +26,12 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';
/**
* Builds the Canvas tab's Image to Image graph.
*/
export const buildCanvasSDXLImageToImageGraph = async (
state: RootState,
initialImage: ImageDTO
): Promise<NonNullableGraph> => {
export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -316,8 +307,6 @@ export const buildCanvasSDXLImageToImageGraph = async (
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -328,7 +317,7 @@ export const buildCanvasSDXLImageToImageGraph = async (
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -349,24 +338,24 @@ export const buildCanvasSDXLImageToImageGraph = async (
// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -41,11 +41,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu
/**
* Builds the Canvas tab's Inpaint graph.
*/
export const buildCanvasSDXLInpaintGraph = async (
export const buildCanvasSDXLInpaintGraph = (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage: ImageDTO
): Promise<NonNullableGraph> => {
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -426,31 +426,24 @@ export const buildCanvasSDXLInpaintGraph = async (
// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefinerToGraph(
state,
graph,
SDXL_DENOISE_LATENTS,
modelLoaderNodeId,
canvasInitImage,
canvasMaskImage
);
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage, canvasMaskImage);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// Add VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -45,11 +45,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu
/**
* Builds the Canvas tab's Outpaint graph.
*/
export const buildCanvasSDXLOutpaintGraph = async (
export const buildCanvasSDXLOutpaintGraph = (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage?: ImageDTO
): Promise<NonNullableGraph> => {
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -555,25 +555,25 @@ export const buildCanvasSDXLOutpaintGraph = async (
// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// Add VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,7 +1,6 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import type { NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -25,12 +24,12 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';
/**
* Builds the Canvas tab's Text to Image graph.
*/
export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -273,8 +272,6 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -287,7 +284,7 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise
negative_prompt: negativePrompt,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -304,24 +301,24 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise
// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// add LoRA support
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);
// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,8 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import type { NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -24,12 +23,12 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';
/**
* Builds the Canvas tab's Text to Image graph.
*/
export const buildCanvasTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -263,8 +262,6 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise<Non
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -275,7 +272,7 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise<Non
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -292,17 +289,17 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise<Non
}
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,13 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import {
type ImageResizeInvocation,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -30,12 +24,12 @@ import {
RESIZE,
SEAMLESS,
} from './constants';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';
/**
* Builds the Image to Image tab graph.
*/
export const buildLinearImageToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -313,8 +307,6 @@ export const buildLinearImageToImageGraph = async (state: RootState): Promise<No
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -325,7 +317,7 @@ export const buildLinearImageToImageGraph = async (state: RootState): Promise<No
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -344,17 +336,17 @@ export const buildLinearImageToImageGraph = async (state: RootState): Promise<No
}
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,12 +1,6 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type ImageResizeInvocation,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -31,12 +25,12 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';
/**
* Builds the Image to Image tab graph.
*/
export const buildLinearSDXLImageToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -324,8 +318,6 @@ export const buildLinearSDXLImageToImageGraph = async (state: RootState): Promis
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -336,7 +328,7 @@ export const buildLinearSDXLImageToImageGraph = async (state: RootState): Promis
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -357,25 +349,25 @@ export const buildLinearSDXLImageToImageGraph = async (state: RootState): Promis
// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);
// Add LoRA Support
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,7 +1,6 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import type { NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -24,9 +23,9 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';
export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -222,8 +221,6 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
],
};
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -234,7 +231,7 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -253,25 +250,25 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// add IP Adapter
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,8 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import type { NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addHrfToGraph } from './addHrfToGraph';
@ -24,9 +23,9 @@ import {
SEAMLESS,
TEXT_TO_IMAGE_GRAPH,
} from './constants';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';
export const buildLinearTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -213,8 +212,6 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non
],
};
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -225,7 +222,7 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -242,18 +239,18 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non
}
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// add IP Adapter
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// High resolution fix.
if (state.hrf.hrfEnabled) {

View File

@ -1,5 +1,5 @@
import type { JSONObject } from 'common/types';
import type { AnyModelConfig, CoreMetadataInvocation, ModelMetadataField, NonNullableGraph } from 'services/api/types';
import type { CoreMetadataInvocation, NonNullableGraph } from 'services/api/types';
import { METADATA } from './constants';
@ -71,11 +71,3 @@ export const setMetadataReceivingNode = (graph: NonNullableGraph, nodeId: string
},
});
};
export const getModelMetadataField = ({ key, hash, name, base, type }: AnyModelConfig): ModelMetadataField => ({
key,
hash,
name,
base,
type,
});

View File

@ -38,7 +38,6 @@ export const ParamNegativePrompt = memo(() => {
onKeyDown={onKeyDown}
fontSize="sm"
variant="darkFilled"
paddingRight={30}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />

View File

@ -54,7 +54,6 @@ export const ParamPositivePrompt = memo(() => {
minH={28}
onKeyDown={onKeyDown}
variant="darkFilled"
paddingRight={30}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />

View File

@ -41,7 +41,6 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
onKeyDown={onKeyDown}
fontSize="sm"
variant="darkFilled"
paddingRight={30}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />

View File

@ -38,7 +38,6 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
onKeyDown={onKeyDown}
fontSize="sm"
variant="darkFilled"
paddingRight={30}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />

View File

@ -109,7 +109,7 @@ export const imagesApi = api.injectEndpoints({
}),
clearIntermediates: build.mutation<number, void>({
query: () => ({ url: buildImagesUrl('intermediates'), method: 'DELETE' }),
invalidatesTags: ['IntermediatesCount', 'InvocationCacheStatus'],
invalidatesTags: ['IntermediatesCount'],
}),
getImageDTO: build.query<ImageDTO, string>({
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}`) }),
@ -125,7 +125,7 @@ export const imagesApi = api.injectEndpoints({
paths['/api/v1/images/i/{image_name}/workflow']['get']['responses']['200']['content']['application/json'],
string
>({
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/workflow`) }),
query: (image_name) => ({ url: `images/i/${image_name}/workflow` }),
providesTags: (result, error, image_name) => [{ type: 'ImageWorkflow', id: image_name }],
keepUnusedDataFor: 86400, // 24 hours
}),

View File

@ -23,14 +23,7 @@ export type UpdateModelArg = {
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
};
export type UpdateModelImageArg = {
key: string;
image: Blob;
};
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateModelImageResponse =
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
@ -40,7 +33,6 @@ type DeleteModelArg = {
key: string;
};
type DeleteModelResponse = void;
type DeleteModelImageResponse = void;
type ConvertMainModelResponse =
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
@ -152,18 +144,6 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
query: ({ key, image }) => {
const formData = new FormData();
formData.append('image', image);
return {
url: buildModelsUrl(`i/${key}/image`),
method: 'PATCH',
body: formData,
};
},
invalidatesTags: ['Model'],
}),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source }) => {
return {
@ -183,18 +163,6 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
deleteModelImage: build.mutation<DeleteModelImageResponse, string>({
query: (key) => {
return {
url: buildModelsUrl(`i/${key}/image`),
method: 'DELETE',
};
},
invalidatesTags: ['Model'],
}),
getModelImage: build.query<string, string>({
query: (key) => buildModelsUrl(`i/${key}/image`),
}),
convertModel: build.mutation<ConvertMainModelResponse, string>({
query: (key) => {
return {
@ -362,9 +330,7 @@ export const {
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
useDeleteModelsMutation,
useDeleteModelImageMutation,
useUpdateModelMutation,
useUpdateModelImageMutation,
useInstallModelMutation,
useConvertModelMutation,
useSyncModelsMutation,

File diff suppressed because one or more lines are too long

View File

@ -38,6 +38,7 @@ export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDT
export type ModelType = S['ModelType'];
export type SubModelType = S['SubModelType'];
export type BaseModelType = S['BaseModelType'];
export type ControlField = S['ControlField'];
// Model Configs
@ -119,18 +120,17 @@ export type CreateGradientMaskInvocation = S['CreateGradientMaskInvocation'];
export type CanvasPasteBackInvocation = S['CanvasPasteBackInvocation'];
export type NoiseInvocation = S['NoiseInvocation'];
export type DenoiseLatentsInvocation = S['DenoiseLatentsInvocation'];
export type SDXLLoRALoaderInvocation = S['SDXLLoRALoaderInvocation'];
export type SDXLLoraLoaderInvocation = S['SDXLLoraLoaderInvocation'];
export type ImageToLatentsInvocation = S['ImageToLatentsInvocation'];
export type LatentsToImageInvocation = S['LatentsToImageInvocation'];
export type LoRALoaderInvocation = S['LoRALoaderInvocation'];
export type LoraLoaderInvocation = S['LoraLoaderInvocation'];
export type ESRGANInvocation = S['ESRGANInvocation'];
export type ImageNSFWBlurInvocation = S['ImageNSFWBlurInvocation'];
export type ImageWatermarkInvocation = S['ImageWatermarkInvocation'];
export type SeamlessModeInvocation = S['SeamlessModeInvocation'];
export type CoreMetadataInvocation = S['CoreMetadataInvocation'];
// Metadata fields
export type ModelMetadataField = S['ModelMetadataField'];
export type IPAdapterMetadataField = S['IPAdapterMetadataField'];
export type T2IAdapterField = S['T2IAdapterField'];
// ControlNet Nodes
export type ControlNetInvocation = S['ControlNetInvocation'];

View File

@ -33,15 +33,19 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.latent import SchedulerOutput
from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOutput, MetadataOutput
from invokeai.app.invocations.model import (
CLIPField,
ClipField,
CLIPOutput,
LoRALoaderOutput,
ModelField,
LoraInfo,
LoraLoaderOutput,
LoRAModelField,
MainModelField,
ModelInfo,
ModelLoaderOutput,
SDXLLoRALoaderOutput,
SDXLLoraLoaderOutput,
UNetField,
UNetOutput,
VAEField,
VaeField,
VAEModelField,
VAEOutput,
)
from invokeai.app.invocations.primitives import (
@ -69,8 +73,8 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_management.model_manager import LoadedModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
@ -114,16 +118,20 @@ __all__ = [
"MetadataItemOutput",
"MetadataOutput",
# invokeai.app.invocations.model
"ModelField",
"ModelInfo",
"LoraInfo",
"UNetField",
"CLIPField",
"VAEField",
"ClipField",
"VaeField",
"MainModelField",
"LoRAModelField",
"VAEModelField",
"UNetOutput",
"VAEOutput",
"CLIPOutput",
"ModelLoaderOutput",
"LoRALoaderOutput",
"SDXLLoRALoaderOutput",
"LoraLoaderOutput",
"SDXLLoraLoaderOutput",
# invokeai.app.invocations.primitives
"BooleanCollectionOutput",
"BooleanOutput",
@ -158,7 +166,7 @@ __all__ = [
# invokeai.app.services.config.config_default
"InvokeAIAppConfig",
# invokeai.backend.model_management.model_manager
"LoadedModel",
"LoadedModelInfo",
# invokeai.backend.model_management.models.base
"BaseModelType",
"ModelType",

Some files were not shown because too many files have changed in this diff Show More