Compare commits

...

55 Commits

Author SHA1 Message Date
52039a8d91 lint fix 2024-03-07 13:44:22 -05:00
dc7154439d restore redux action for model changed so that model actually changes 2024-03-07 13:43:39 -05:00
4bbe6f3548 fix image processors to work with new cnet/t2i model format 2024-03-07 11:40:07 -05:00
ad70cdfe87 feat: undo/redo discard canvas staged image 2024-03-07 19:24:55 +11:00
549d461107 refactor: 🚨 satisfy the linter 2024-03-07 19:24:55 +11:00
cab3748010 feat: discard current inpaint item 2024-03-07 19:24:55 +11:00
779b3e0e8e tidy(ui): remove npm lockfile 2024-03-06 21:57:41 -05:00
9b48029bc9 tidy(mm): ModelImages service 2024-03-06 21:57:41 -05:00
347f1fd0b7 fix tests 2024-03-06 21:57:41 -05:00
4af5a09a68 cleanup 2024-03-06 21:57:41 -05:00
8df02623f2 cleanup 2024-03-06 21:57:41 -05:00
aa88fadc30 use webp images 2024-03-06 21:57:41 -05:00
8411029d93 get model image url from model config, added thumbnail formatting for images 2024-03-06 21:57:41 -05:00
239b1e8cc7 moved upload image field and added delete image functionality 2024-03-06 21:57:41 -05:00
8a68355926 got model images displaying, still need to clean up types and unused code 2024-03-06 21:57:41 -05:00
86aef9f31d removed modelimage for now 2024-03-06 21:57:41 -05:00
2f6964bfa5 fetching model image, still not working 2024-03-06 21:57:41 -05:00
c1cdfd132b moved model image to edit page, added model_images service 2024-03-06 21:57:41 -05:00
f6bfe5e6f2 created ugly model image upload component 2024-03-06 21:57:41 -05:00
b5a8455b5f translationBot(ui): update translation (Russian)
Currently translated at 94.6% (1431 of 1512 strings)

translationBot(ui): update translation (Russian)

Currently translated at 94.6% (1431 of 1512 strings)

Co-authored-by: Васянатор <ilabulanov339@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translation: InvokeAI/Web UI
2024-03-07 11:47:01 +11:00
645ef081ea translationBot(ui): update translation (Italian)
Currently translated at 98.0% (1487 of 1516 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.0% (1482 of 1512 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.0% (1475 of 1505 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-03-07 11:47:01 +11:00
e68d7fa6d7 fix(ui): update types 2024-03-07 10:56:59 +11:00
c5ab1c7ad6 chore(ui): typegen 2024-03-07 10:56:59 +11:00
5a561cab78 fix(ui): typo 2024-03-07 10:56:59 +11:00
132790eebe tidy(nodes): use canonical capitalizations 2024-03-07 10:56:59 +11:00
c57f6ee885 fix(ui): fix metadata for graphs to use new enriched format 2024-03-07 10:56:59 +11:00
d4a2ea68fc chore(ui): typegen 2024-03-07 10:56:59 +11:00
528ac5dd25 refactor(nodes): model identifiers
- All models are identified by a key and, optionally, a submodel type via the new `ModelField` model. Previously, a few model types had their own class, but not all of them; that inconsistency added complexity without any benefit.
- Update all invocations to use the new format.
- In the node API, models are loaded by key or an instance of `ModelField` as a convenience.
- Add an enriched model schema for metadata. It includes key, hash, name, base and type.
2024-03-07 10:56:59 +11:00
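A minimal sketch of the identifier format described in the commit above, assuming pydantic v2. The class and field names mirror the `ModelField` and `ModelMetadataField` definitions that appear in the diffs further down; the `SubModelType` values and the simplified `base`/`type` fields are assumptions for illustration, not the project's actual enums.

from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field


class SubModelType(str, Enum):
    # Assumed values; the real enum lives in the backend model manager package.
    UNet = "unet"
    Scheduler = "scheduler"
    Tokenizer = "tokenizer"
    TextEncoder = "text_encoder"
    VAE = "vae"


class ModelField(BaseModel):
    """Identifies any model by key, plus an optional submodel type."""

    key: str = Field(description="Key of the model")
    submodel_type: Optional[SubModelType] = Field(default=None, description="Submodel type")


class ModelMetadataField(BaseModel):
    """Enriched schema for metadata: key, hash, name, base and type."""

    key: str
    hash: str
    name: str
    base: str  # BaseModelType in the real config module
    type: str  # ModelType in the real config module


# Selecting a submodel is just a copy of the identifier with submodel_type set,
# as the loader invocations in the diffs below do with model_copy(update=...):
main = ModelField(key="abc123")
unet = main.model_copy(update={"submodel_type": SubModelType.UNet})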
afd9ae7712 tidy(mm): remove convenience methods from high level model manager service
These were added as a stopgap for the nodes API changes and are no longer needed. A follow-up commit will fix the nodes API to not rely on them.
2024-03-07 10:56:59 +11:00
4eefed12f0 refactor: 🚨 please the almighty linter 2024-03-07 10:44:40 +11:00
4301a3d6fd feat: invert scroll direction for brush size 2024-03-07 10:44:40 +11:00
99c0662e3f fix(nodes): load config before doing anything else
This was preventing custom nodes from loading if a custom nodes directory was specified.

Closes #5862
2024-03-07 10:36:27 +11:00
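A self-contained sketch (not InvokeAI's actual code) of the load-order bug described above: if configuration is only parsed after modules have already consulted it, those modules see default values rather than the custom nodes directory given on the command line, so the configured directory is never scanned. The `--custom-nodes-dir` flag and helper names here are hypothetical.

import sys
from dataclasses import dataclass


@dataclass
class AppConfig:
    custom_nodes_dir: str | None = None  # default; the real value comes from CLI or config file


CONFIG = AppConfig()  # module-level singleton that other code consults


def parse_args(argv: list[str]) -> None:
    # Hypothetical, deliberately simplified parsing of --custom-nodes-dir PATH.
    if "--custom-nodes-dir" in argv:
        CONFIG.custom_nodes_dir = argv[argv.index("--custom-nodes-dir") + 1]


def load_custom_nodes() -> None:
    # Reads CONFIG at call time. If this ran (or the value were captured)
    # before parse_args(), custom_nodes_dir would still be None and the
    # user's directory would silently be skipped.
    if CONFIG.custom_nodes_dir:
        sys.path.insert(0, CONFIG.custom_nodes_dir)
        # ...import each node module found in that directory...


if __name__ == "__main__":
    parse_args(sys.argv[1:])  # must happen before anything reads CONFIG
    load_custom_nodes()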
cdc0d0c182 add config_path to ModelRecordChanges 2024-03-07 10:29:29 +11:00
a00369a67a add config path as field in model update form when model is a checkpoint 2024-03-07 10:29:29 +11:00
4f096ac3ba feat(scripts): add frontend-types to Makefile to generate types 2024-03-07 10:16:44 +11:00
f5e3341465 feat(scripts): add support for file path & stdin to frontend typegen script 2024-03-07 10:16:44 +11:00
474852ef7e feat(scripts): add script to generate openapi schema 2024-03-07 10:16:44 +11:00
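Based only on the Makefile target shown further down (`python ../../../scripts/generate_openapi_schema.py | pnpm typegen`), a script like this plausibly just dumps the FastAPI app's OpenAPI schema to stdout as JSON; the import path for the app object is an assumption, not the project's actual layout.

#!/usr/bin/env python
"""Illustrative sketch: print the OpenAPI schema of a FastAPI app as JSON."""

import json
import sys

from invokeai.app.api_app import app  # assumed location of the FastAPI app object


def main() -> None:
    # FastAPI builds the schema on demand; openapi() returns a plain dict.
    json.dump(app.openapi(), sys.stdout, indent=2)
    sys.stdout.write("\n")


if __name__ == "__main__":
    main()

The companion commit lets the frontend typegen script read that schema either from a file path argument or from stdin, which is what the pipe in the Makefile target relies on.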
b1d72d411e only show default settings on main models 2024-03-07 09:07:43 +11:00
46614ee28f lint fix 2024-03-06 15:06:27 -05:00
b019f9bb8b make sure all metadata in viewer is rendered at correct font size - specifically fixes control adapter metadata being too big 2024-03-06 15:06:27 -05:00
b857692073 update uploads from canvas to controlnet to be intermediates so they do not appear in gallery 2024-03-06 15:06:27 -05:00
90fb7a1a59 move linear tab to be first on workflow edit mode 2024-03-06 15:06:27 -05:00
56fcf6af78 empty state for workflow with no linear fields in view mode 2024-03-06 15:06:27 -05:00
c4fe7e697b add right-padding to prompt textareas so that text does not go behind icons 2024-03-06 15:06:27 -05:00
2fd483dfc8 use base.800 on invokeBlue.400 for all gallery selected states 2024-03-06 15:06:27 -05:00
b9a9507422 update padding in color picker 2024-03-06 15:06:27 -05:00
f2744fd7d1 fix URL for get image workflow (#5882)
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2024-03-06 12:46:16 -05:00
fe6e879d38 fix(ui): invalidate InvocationCacheStatus query cache after clearing intermediates 2024-03-06 08:14:12 -05:00
d3ab08fe10 tests: add invocation cache tests 2024-03-06 08:14:12 -05:00
b0615bdfd4 fix(nodes): correctly serialize outputs
In order for delete by match to work, we need the whole invocation output to be stringified.

For some reason, the serialization of the output was set to only include the `type` field. It should instead include the whole output.

I don't understand how this ever worked unless pydantic had different serialization behaviour in v1 (though it appears to have been the same).

Closes #5805
2024-03-06 08:14:12 -05:00
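A small pydantic v2 sketch of the serialization difference described above, using a simplified stand-in for an invocation output: with `include={"type"}` only the type field survives, so matching a cached entry against anything else (an image name, for example) can never succeed.

from pydantic import BaseModel


class ImageOutput(BaseModel):
    """Simplified stand-in for an invocation output."""

    type: str = "image_output"
    image_name: str = "example.png"
    width: int = 512
    height: int = 512


out = ImageOutput()

# What the cache previously stored: only the type field.
print(out.model_dump_json(include={"type"}))
# {"type":"image_output"}

# What delete-by-match needs: the whole output, so a search for
# "example.png" in the cached string can actually find it.
print(out.model_dump_json())
# {"type":"image_output","image_name":"example.png","width":512,"height":512}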
bab20467fb fix(nodes): fix invocation cache clear method args 2024-03-06 08:14:12 -05:00
e24624109e fix(nodes): fix invocation cache ABC typing 2024-03-06 08:14:12 -05:00
458e7185b8 fix: 🐛 didn't include renamed file 2024-03-06 20:06:14 +11:00
a95128f5f2 refactor: ✏️ canvas mask compositor naming
changes `...MaskCompositer` spelling to `...MaskCompositor`
2024-03-06 20:06:14 +11:00
46f32c5e3c Remove references to the no longer existing invokeai.app.services.model_metadata package 2024-03-05 19:58:25 -05:00
101 changed files with 2383 additions and 1129 deletions

View File

@ -10,10 +10,11 @@ help:
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "test" Run the unit tests.
@echo "frontend-install" Install the pnpm modules needed for the front end
@echo "test Run the unit tests."
@echo "frontend-install Install the pnpm modules needed for the front end"
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@ -53,6 +54,9 @@ frontend-build:
frontend-dev:
cd invokeai/frontend/web && pnpm dev
frontend-typegen:
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
# Installer zip file
installer-zip:
cd installer && ./create_installer.sh

View File

@ -25,6 +25,7 @@ from ..services.invocation_cache.invocation_cache_memory import MemoryInvocation
from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_images.model_images_default import ModelImageFileStorageDisk
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
@ -71,6 +72,8 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images")
model_images_folder = config.models_path
db = init_db(config=config, logger=logger, image_files=image_files)
configuration = config
@ -92,6 +95,7 @@ class ApiDependencies:
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db),
@ -118,6 +122,7 @@ class ApiDependencies:
images=images,
invocation_cache=invocation_cache,
logger=logger,
model_images=model_images_service,
model_manager=model_manager,
download_queue=download_queue_service,
names=names,

View File

@ -1,12 +1,16 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
import io
import pathlib
import shutil
import traceback
from typing import Any, Dict, List, Optional
from fastapi import Body, Path, Query, Response
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
@ -31,6 +35,9 @@ from ..dependencies import ApiDependencies
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
# images are immutable; set a high max-age
IMAGE_MAX_AGE = 31536000
class ModelsList(BaseModel):
"""Return list of configs."""
@ -105,6 +112,9 @@ async def list_model_records(
found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
)
for model in found_models:
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
model.cover_image = cover_image
return ModelsList(models=found_models)
@ -148,6 +158,8 @@ async def get_model_record(
record_store = ApiDependencies.invoker.services.model_manager.store
try:
config: AnyModelConfig = record_store.get_model(key)
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
config.cover_image = cover_image
return config
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@ -266,6 +278,75 @@ async def update_model_record(
return model_response
@model_manager_router.get(
"/i/{key}/image",
operation_id="get_model_image",
responses={
200: {
"description": "The model image was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The model image could not be found"},
},
status_code=200,
)
async def get_model_image(
key: str = Path(description="The name of model image file to get"),
) -> FileResponse:
"""Gets an image file that previews the model"""
try:
path = ApiDependencies.invoker.services.model_images.get_path(key)
response = FileResponse(
path,
media_type="image/png",
filename=key + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)
@model_manager_router.patch(
"/i/{key}/image",
operation_id="update_model_image",
responses={
200: {
"description": "The model image was updated successfully",
},
400: {"description": "Bad request"},
},
status_code=200,
)
async def update_model_image(
key: Annotated[str, Path(description="Unique key of model")],
image: UploadFile,
) -> None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
logger = ApiDependencies.invoker.services.logger
model_images = ApiDependencies.invoker.services.model_images
try:
model_images.save(pil_image, key)
logger.info(f"Updated image for model: {key}")
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return
@model_manager_router.delete(
"/i/{key}",
operation_id="delete_model",
@ -296,6 +377,29 @@ async def delete_model(
raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.delete(
"/i/{key}/image",
operation_id="delete_model_image",
responses={
204: {"description": "Model image deleted successfully"},
404: {"description": "Model image not found"},
},
status_code=204,
)
async def delete_model_image(
key: str = Path(description="Unique key of model image to remove from model_images directory."),
) -> None:
logger = ApiDependencies.invoker.services.logger
model_images = ApiDependencies.invoker.services.model_images
try:
model_images.delete(key)
logger.info(f"Deleted model image: {key}")
return
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
# @model_manager_router.post(
# "/i/",
# operation_id="add_model_record",
@ -539,7 +643,7 @@ async def convert_model(
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
# loading the model will convert it into a cached diffusers file
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)

View File

@ -2,12 +2,9 @@
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
import sys
from contextlib import asynccontextmanager
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.version.invokeai_version import __version__
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from .services.config import InvokeAIAppConfig
app_config = InvokeAIAppConfig.get_config()
@ -20,6 +17,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
import asyncio
import mimetypes
import socket
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any
@ -40,6 +38,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
@ -59,6 +58,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)

View File

@ -20,7 +20,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
from invokeai.backend.util.devices import torch_dtype
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .model import ClipField
from .model import CLIPField
# unconditioned: Optional[torch.Tensor]
@ -46,7 +46,7 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.compel_prompt,
ui_component=UIComponent.Textarea,
)
clip: ClipField = InputField(
clip: CLIPField = InputField(
title="CLIP",
description=FieldDescriptions.clip,
input=Input.Connection,
@ -54,16 +54,16 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
tokenizer_info = context.models.load(self.clip.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
text_encoder_info = context.models.load(self.clip.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
@ -127,16 +127,16 @@ class SDXLPromptInvocationBase:
def run_clip_compel(
self,
context: InvocationContext,
clip_field: ClipField,
clip_field: CLIPField,
prompt: str,
get_pooled: bool,
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
tokenizer_info = context.models.load(clip_field.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
text_encoder_info = context.models.load(clip_field.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
@ -163,7 +163,7 @@ class SDXLPromptInvocationBase:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
lora_info = context.models.load(lora.lora)
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)
@ -253,8 +253,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
crop_left: int = InputField(default=0, description="")
target_width: int = InputField(default=1024, description="")
target_height: int = InputField(default=1024, description="")
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -340,7 +340,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
crop_top: int = InputField(default=0, description="")
crop_left: int = InputField(default=0, description="")
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -370,10 +370,10 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output"""
class CLIPSkipInvocationOutput(BaseInvocationOutput):
"""CLIP skip node output"""
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation(
@ -383,15 +383,15 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
category="conditioning",
version="1.0.0",
)
class ClipSkipInvocation(BaseInvocation):
class CLIPSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
def invoke(self, context: InvocationContext) -> CLIPSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers
return ClipSkipInvocationOutput(
return CLIPSkipInvocationOutput(
clip=self.clip,
)

View File

@ -34,6 +34,7 @@ from invokeai.app.invocations.fields import (
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
@ -51,15 +52,9 @@ CONTROLNET_RESIZE_VALUES = Literal[
]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
key: str = Field(description="Model config record key for the ControlNet model")
class ControlField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
control_model: ModelField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@ -95,7 +90,7 @@ class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_model: ModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)

View File

@ -228,7 +228,7 @@ class ConditioningField(BaseModel):
# endregion
class MetadataField(RootModel):
class MetadataField(RootModel[dict[str, Any]]):
"""
Pydantic model for metadata with custom root of type dict[str, Any].
Metadata is stored without a strict schema.

View File

@ -11,25 +11,17 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType
# LS: Consider moving these two classes into model.py
class IPAdapterModelField(BaseModel):
key: str = Field(description="Key to the IP-Adapter model")
class CLIPVisionModelField(BaseModel):
key: str = Field(description="Key to the CLIP Vision image encoder model")
from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType
class IPAdapterField(BaseModel):
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
@ -62,7 +54,7 @@ class IPAdapterInvocation(BaseInvocation):
# Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: IPAdapterModelField = InputField(
ip_adapter_model: ModelField = InputField(
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
)
@ -90,18 +82,18 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, IPAdapterConfig)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
assert len(image_encoder_models) == 1
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=image_encoder_model,
image_encoder_model=ModelField(key=image_encoder_models[0].key),
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,

View File

@ -26,6 +26,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
from PIL import Image, ImageFilter
from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
@ -75,7 +76,7 @@ from .baseinvocation import (
invocation_output,
)
from .controlnet_image_processors import ControlField
from .model import ModelInfo, UNetField, VaeField
from .model import ModelField, UNetField, VAEField
if choose_torch_device() == torch.device("mps"):
from torch import mps
@ -118,7 +119,7 @@ class SchedulerInvocation(BaseInvocation):
class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
@ -153,7 +154,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
if image_tensor is not None:
vae_info = context.models.load(**self.vae.vae.model_dump())
vae_info = context.models.load(self.vae.vae)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
@ -244,12 +245,12 @@ class CreateGradientMaskInvocation(BaseInvocation):
def get_scheduler(
context: InvocationContext,
scheduler_info: ModelInfo,
scheduler_info: ModelField,
scheduler_name: str,
seed: int,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@ -461,7 +462,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# and if weight is None, populate with default 1.0?
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
# control_models.append(control_model)
control_image_field = control_info.image
@ -523,11 +524,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
context.models.load(single_ip_adapter.ip_adapter_model)
)
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
@ -538,6 +538,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
single_ipa_images, image_encoder_model
@ -577,8 +578,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
image = context.images.get_pil(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@ -731,12 +732,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
unet_info = context.models.load(**self.unet.unet.model_dump())
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
@ -830,7 +832,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VaeField = InputField(
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@ -841,8 +843,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(**self.vae.vae.model_dump())
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module)
latents = latents.to(vae.device)
@ -1008,7 +1010,7 @@ class ImageToLatentsInvocation(BaseInvocation):
image: ImageField = InputField(
description="The image to encode",
)
vae: VaeField = InputField(
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@ -1064,7 +1066,7 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(**self.vae.vae.model_dump())
vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:

View File

@ -8,7 +8,10 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.controlnet_image_processors import (
CONTROLNET_MODE_VALUES,
CONTROLNET_RESIZE_VALUES,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@ -17,10 +20,8 @@ from invokeai.app.invocations.fields import (
OutputField,
UIType,
)
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from ...version import __version__
@ -30,10 +31,20 @@ class MetadataItemField(BaseModel):
value: Any = Field(description=FieldDescriptions.metadata_item_value)
class ModelMetadataField(BaseModel):
"""Model Metadata Field"""
key: str
hash: str
name: str
base: BaseModelType
type: ModelType
class LoRAMetadataField(BaseModel):
"""LoRA Metadata Field"""
model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
model: ModelMetadataField = Field(description=FieldDescriptions.lora_model)
weight: float = Field(description=FieldDescriptions.lora_weight)
@ -41,7 +52,7 @@ class IPAdapterMetadataField(BaseModel):
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = Field(
ip_adapter_model: ModelMetadataField = Field(
description="The IP-Adapter model.",
)
weight: Union[float, list[float]] = Field(
@ -51,6 +62,33 @@ class IPAdapterMetadataField(BaseModel):
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
class T2IAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
class ControlNetMetadataField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelMetadataField = Field(description="The ControlNet model to use")
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@invocation_output("metadata_item_output")
class MetadataItemOutput(BaseInvocationOutput):
"""Metadata Item Output"""
@ -140,14 +178,14 @@ class CoreMetadataInvocation(BaseInvocation):
default=None,
description="The number of skipped CLIP layers",
)
model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference")
controlnets: Optional[list[ControlField]] = InputField(
model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference")
controlnets: Optional[list[ControlNetMetadataField]] = InputField(
default=None, description="The ControlNets used for inference"
)
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
default=None, description="The IP Adapters used for inference"
)
t2iAdapters: Optional[list[T2IAdapterField]] = InputField(
t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField(
default=None, description="The IP Adapters used for inference"
)
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
@ -159,7 +197,7 @@ class CoreMetadataInvocation(BaseInvocation):
default=None,
description="The name of the initial image",
)
vae: Optional[VAEModelField] = InputField(
vae: Optional[ModelMetadataField] = InputField(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
@ -190,7 +228,7 @@ class CoreMetadataInvocation(BaseInvocation):
)
# SDXL Refiner
refiner_model: Optional[MainModelField] = InputField(
refiner_model: Optional[ModelMetadataField] = InputField(
default=None,
description="The SDXL Refiner model used",
)
@ -222,10 +260,9 @@ class CoreMetadataInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> MetadataOutput:
"""Collects and outputs a CoreMetadata object"""
return MetadataOutput(
metadata=MetadataField.model_validate(
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
)
)
as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
as_dict["app_version"] = __version__
return MetadataOutput(metadata=MetadataField.model_validate(as_dict))
model_config = ConfigDict(extra="allow")

View File

@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import SubModelType
from ...backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -16,33 +16,33 @@ from .baseinvocation import (
)
class ModelInfo(BaseModel):
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
class ModelField(BaseModel):
key: str = Field(description="Key of the model")
submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model")
class LoRAField(BaseModel):
lora: ModelField = Field(description="Info to load lora model")
weight: float = Field(description="Weight to apply to lora model")
class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
unet: ModelField = Field(description="Info to load unet submodel")
scheduler: ModelField = Field(description="Info to load scheduler submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
class CLIPField(BaseModel):
tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
class VAEField(BaseModel):
vae: ModelField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@ -57,14 +57,14 @@ class UNetOutput(BaseInvocationOutput):
class VAEOutput(BaseInvocationOutput):
"""Base class for invocations that output a VAE field"""
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation_output("clip_output")
class CLIPOutput(BaseInvocationOutput):
"""Base class for invocations that output a CLIP field"""
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
@invocation_output("model_loader_output")
@ -74,18 +74,6 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
pass
class MainModelField(BaseModel):
"""Main model field"""
key: str = Field(description="Model key")
class LoRAModelField(BaseModel):
"""LoRA model field"""
key: str = Field(description="LoRA model key")
@invocation(
"main_model_loader",
title="Main Model",
@ -96,62 +84,40 @@ class LoRAModelField(BaseModel):
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
key = self.model.key
# TODO: not found exceptions
if not context.models.exists(key):
raise Exception(f"Unknown model {key}")
if not context.models.exists(self.model.key):
raise Exception(f"Unknown model {self.model.key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
key=key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=key,
submodel_type=SubModelType.VAE,
),
),
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
)
@invocation_output("lora_loader_output")
class LoraLoaderOutput(BaseInvocationOutput):
class LoRALoaderOutput(BaseInvocationOutput):
"""Model loader output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
class LoraLoaderInvocation(BaseInvocation):
class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None,
@ -159,46 +125,41 @@ class LoraLoaderInvocation(BaseInvocation):
input=Input.Connection,
title="UNet",
)
clip: Optional[ClipField] = InputField(
clip: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
)
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise Exception(f"Unkown lora: {lora_key}!")
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'LoRA "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip')
output = LoraLoaderOutput()
output = LoRALoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
@ -207,12 +168,12 @@ class LoraLoaderInvocation(BaseInvocation):
@invocation_output("sdxl_lora_loader_output")
class SDXLLoraLoaderOutput(BaseInvocationOutput):
class SDXLLoRALoaderOutput(BaseInvocationOutput):
"""SDXL LoRA Loader Output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
clip2: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
@invocation(
@ -222,10 +183,10 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
category="model",
version="1.0.1",
)
class SDXLLoraLoaderInvocation(BaseInvocation):
class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None,
@ -233,65 +194,59 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
input=Input.Connection,
title="UNet",
)
clip: Optional[ClipField] = InputField(
clip: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP 1",
)
clip2: Optional[ClipField] = InputField(
clip2: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP 2",
)
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!")
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'LoRA "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip')
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip2')
if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip2')
output = SDXLLoraLoaderOutput()
output = SDXLLoRALoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.clip2 is not None:
output.clip2 = copy.deepcopy(self.clip2)
output.clip2 = self.clip2.model_copy(deep=True)
output.clip2.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
@ -299,17 +254,11 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
return output
class VAEModelField(BaseModel):
"""Vae model field"""
key: str = Field(description="Model's key")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
class VaeLoaderInvocation(BaseInvocation):
class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: VAEModelField = InputField(
vae_model: ModelField = InputField(
description=FieldDescriptions.vae_model,
input=Input.Direct,
title="VAE",
@ -321,7 +270,7 @@ class VaeLoaderInvocation(BaseInvocation):
if not context.models.exists(key):
raise Exception(f"Unkown vae: {key}!")
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
return VAEOutput(vae=VAEField(vae=self.vae_model))
@invocation_output("seamless_output")
@ -329,7 +278,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
"""Modified Seamless Model output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
vae: Optional[VAEField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
@invocation(
@ -348,7 +297,7 @@ class SeamlessModeInvocation(BaseInvocation):
input=Input.Connection,
title="UNet",
)
vae: Optional[VaeField] = InputField(
vae: Optional[VAEField] = InputField(
default=None,
description=FieldDescriptions.vae_model,
input=Input.Connection,

View File

@ -8,7 +8,7 @@ from .baseinvocation import (
invocation,
invocation_output,
)
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
from .model import CLIPField, ModelField, UNetField, VAEField
@invocation_output("sdxl_model_loader_output")
@ -16,9 +16,9 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL base model loader output"""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation_output("sdxl_refiner_model_loader_output")
@ -26,15 +26,15 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output"""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
model: MainModelField = InputField(
model: ModelField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
)
# TODO: precision?
@ -46,48 +46,19 @@ class SDXLModelLoaderInvocation(BaseInvocation):
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SDXLModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
key=model_key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.VAE,
),
),
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
)
@ -101,10 +72,8 @@ class SDXLModelLoaderInvocation(BaseInvocation):
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model,
input=Input.Direct,
ui_type=UIType.SDXLRefinerModel,
model: ModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
)
# TODO: precision?
@ -115,34 +84,14 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SDXLRefinerModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
key=model_key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.VAE,
),
),
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
)
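
The refactored loaders derive every submodel reference from the node's single ModelField by copying it with a different submodel_type, instead of assembling per-submodel ModelInfo entries. A minimal sketch of the Pydantic pattern, with a placeholder key and placeholder enum values (only the member names mirror the diff):

from enum import Enum
from typing import Optional

from pydantic import BaseModel


class SubModelType(str, Enum):
    # Placeholder values; only the member names mirror the diff.
    UNet = "unet"
    VAE = "vae"


class ModelField(BaseModel):
    key: str
    submodel_type: Optional[SubModelType] = None


base = ModelField(key="some-model-key")
# model_copy(update=...) returns a new, independent identifier for one submodel.
unet = base.model_copy(update={"submodel_type": SubModelType.UNet})
vae = base.model_copy(update={"submodel_type": SubModelType.VAE})
assert base.submodel_type is None
assert unet.key == vae.key == "some-model-key"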

View File

@ -10,17 +10,14 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
class T2IAdapterModelField(BaseModel):
key: str = Field(description="Model record key for the T2I-Adapter model")
class T2IAdapterField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
@ -55,7 +52,7 @@ class T2IAdapterInvocation(BaseInvocation):
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
t2i_adapter_model: T2IAdapterModelField = InputField(
t2i_adapter_model: ModelField = InputField(
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
input=Input.Direct,

View File

@ -41,8 +41,9 @@ class InvocationCacheBase(ABC):
"""Clears the cache"""
pass
@staticmethod
@abstractmethod
def create_key(self, invocation: BaseInvocation) -> int:
def create_key(invocation: BaseInvocation) -> int:
"""Gets the key for the invocation's cache item"""
pass

View File

@ -61,9 +61,7 @@ class MemoryInvocationCache(InvocationCacheBase):
self._delete_oldest_access(number_to_delete)
self._cache[key] = CachedItem(
invocation_output,
invocation_output.model_dump_json(
warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
),
invocation_output.model_dump_json(warnings=False, exclude_defaults=True, exclude_unset=True),
)
def _delete_oldest_access(self, number_to_delete: int) -> None:
@ -81,7 +79,7 @@ class MemoryInvocationCache(InvocationCacheBase):
with self._lock:
return self._delete(key)
def clear(self, *args, **kwargs) -> None:
def clear(self) -> None:
with self._lock:
if self._max_cache_size == 0:
return

View File

@ -25,6 +25,7 @@ if TYPE_CHECKING:
from .images.images_base import ImageServiceABC
from .invocation_cache.invocation_cache_base import InvocationCacheBase
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .model_images.model_images_base import ModelImageFileStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
@ -49,6 +50,7 @@ class InvocationServices:
image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase",
logger: "Logger",
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase",
@ -72,6 +74,7 @@ class InvocationServices:
self.image_files = image_files
self.image_records = image_records
self.logger = logger
self.model_images = model_images
self.model_manager = model_manager
self.download_queue = download_queue
self.performance_statistics = performance_statistics

View File

@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from pathlib import Path
from PIL.Image import Image as PILImageType
class ModelImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, model_key: str) -> PILImageType:
"""Retrieves a model image as PIL Image."""
pass
@abstractmethod
def get_path(self, model_key: str) -> Path:
"""Gets the internal path to a model image."""
pass
@abstractmethod
def get_url(self, model_key: str) -> str | None:
"""Gets the URL to fetch a model image."""
pass
@abstractmethod
def save(self, image: PILImageType, model_key: str) -> None:
"""Saves a model image."""
pass
@abstractmethod
def delete(self, model_key: str) -> None:
"""Deletes a model image."""
pass
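
Any backend that satisfies this ABC can be plugged into the services container. As an illustration only (not part of the diff), a minimal in-memory implementation such as might be used in tests:

from pathlib import Path
from typing import Dict, Optional

from PIL.Image import Image as PILImageType

from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase


class ModelImageFileStorageMemory(ModelImageFileStorageBase):
    """Hypothetical in-memory store; keeps PIL images in a dict."""

    def __init__(self) -> None:
        self._images: Dict[str, PILImageType] = {}

    def get(self, model_key: str) -> PILImageType:
        return self._images[model_key]

    def get_path(self, model_key: str) -> Path:
        # Placeholder path; nothing is written to disk in this sketch.
        return Path(f"memory/{model_key}.webp")

    def get_url(self, model_key: str) -> Optional[str]:
        return f"memory/{model_key}" if model_key in self._images else None

    def save(self, image: PILImageType, model_key: str) -> None:
        self._images[model_key] = image

    def delete(self, model_key: str) -> None:
        self._images.pop(model_key, None)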

View File

@ -0,0 +1,20 @@
# TODO: Should these exceptions subclass existing Python exceptions?
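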
class ModelImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Model image file not found"):
super().__init__(message)
class ModelImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Model image file not saved"):
super().__init__(message)
class ModelImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Model image file not deleted"):
super().__init__(message)

View File

@ -0,0 +1,79 @@
from pathlib import Path
from PIL import Image
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.services.invoker import Invoker
from invokeai.app.util.thumbnails import make_thumbnail
from .model_images_base import ModelImageFileStorageBase
from .model_images_common import (
ModelImageFileDeleteException,
ModelImageFileNotFoundException,
ModelImageFileSaveException,
)
class ModelImageFileStorageDisk(ModelImageFileStorageBase):
"""Stores images on disk"""
def __init__(self, model_images_folder: Path):
self._model_images_folder = model_images_folder
self._validate_storage_folders()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def get(self, model_key: str) -> PILImageType:
try:
path = self.get_path(model_key)
if not self._validate_path(path):
raise ModelImageFileNotFoundException
return Image.open(path)
except FileNotFoundError as e:
raise ModelImageFileNotFoundException from e
def save(self, image: PILImageType, model_key: str) -> None:
try:
self._validate_storage_folders()
image_path = self._model_images_folder / (model_key + ".webp")
thumbnail = make_thumbnail(image, 256)
thumbnail.save(image_path, format="webp")
except Exception as e:
raise ModelImageFileSaveException from e
def get_path(self, model_key: str) -> Path:
path = self._model_images_folder / (model_key + ".webp")
return path
def get_url(self, model_key: str) -> str | None:
path = self.get_path(model_key)
if not self._validate_path(path):
return
return self._invoker.services.urls.get_model_image_url(model_key)
def delete(self, model_key: str) -> None:
try:
path = self.get_path(model_key)
if not self._validate_path(path):
raise ModelImageFileNotFoundException
send2trash(path)
except Exception as e:
raise ModelImageFileDeleteException from e
def _validate_path(self, path: Path) -> bool:
"""Validates the path given for an image."""
return path.exists()
def _validate_storage_folders(self) -> None:
"""Checks if the required folders exist and create them if they don't"""
self._model_images_folder.mkdir(parents=True, exist_ok=True)
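
A hedged usage sketch of the disk store above; the folder path, model key, and the disk module name in the import are placeholders. Note that save() writes a 256px webp thumbnail (via make_thumbnail) rather than the original image, and delete() sends the file to the system trash:

from pathlib import Path

from PIL import Image

from invokeai.app.services.model_images.model_images_disk import ModelImageFileStorageDisk  # assumed module name

store = ModelImageFileStorageDisk(Path("outputs/model_images"))
store.save(Image.new("RGB", (512, 512), "gray"), model_key="some-model-key")
print(store.get_path("some-model-key"))   # outputs/model_images/some-model-key.webp
thumbnail = store.get("some-model-key")   # the stored 256px webp, opened with PIL
store.delete("some-model-key")            # moved to trash via send2trash
# get_url() additionally requires the service to have been started with an Invoker.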

View File

@ -1,15 +1,11 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from abc import ABC, abstractmethod
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
@ -70,32 +66,3 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def stop(self, invoker: Invoker) -> None:
pass
@abstractmethod
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@abstractmethod
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass

View File

@ -1,14 +1,10 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
@ -18,7 +14,7 @@ from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase, UnknownModelException
from ..model_records import ModelRecordServiceBase
from .model_manager_base import ModelManagerServiceBase
@ -64,56 +60,6 @@ class ModelManagerService(ModelManagerServiceBase):
if hasattr(service, "stop"):
service.stop(invoker)
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
return self.load.load_model(model_config, submodel_type, context_data)
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
config = self.store.get_model(key)
return self.load.load_model(config, submodel_type, context_data)
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
This is provided for API compatibility with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that was returned.
:param model_name: Name of the model to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self.store.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load.load_model(configs[0], submodel, context_data)
@classmethod
def build_model_manager(
cls,
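
With the convenience loaders gone, callers use the manager's store and load sub-services directly, as the updated invocation context does below. A sketch of the by-attributes path, reconstructed from the deleted method; model_manager is assumed to be a started ModelManagerService:

from typing import Optional

from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType


def load_by_attr(
    model_manager,  # a started ModelManagerService
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
    submodel: Optional[SubModelType] = None,
    context_data=None,
) -> LoadedModel:
    """Search the record store, then load the single matching config."""
    configs = model_manager.store.search_by_attr(model_name, base_model, model_type)
    if len(configs) == 0:
        raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
    if len(configs) > 1:
        raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
    return model_manager.load.load_model(configs[0], submodel, context_data)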

View File

@ -79,6 +79,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
description="The prediction type of the model.", default=None
)
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
config_path: Optional[str] = Field(description="Path to config file for model", default=None)
class ModelRecordServiceBase(ABC):
@ -129,6 +130,17 @@ class ModelRecordServiceBase(ABC):
"""
pass
@abstractmethod
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
"""
Retrieve the configuration for the indicated model.
:param hash: Hash of model config to be fetched.
Exceptions: UnknownModelException
"""
pass
@abstractmethod
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default

View File

@ -203,6 +203,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the database.
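
A hedged usage sketch of the new hash lookup; the record store is assumed to be initialized elsewhere and the formatted output is illustrative:

from invokeai.app.services.model_records import ModelRecordServiceSQL, UnknownModelException


def describe_model_by_hash(record_store: ModelRecordServiceSQL, hash: str) -> str:
    """Illustrative only: resolve a stored hash to a short description."""
    try:
        config = record_store.get_model_by_hash(hash)
        return f"{config.name} ({config.base}, {config.type})"
    except UnknownModelException:
        return "no model with that hash"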

View File

@ -1,7 +1,7 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
from PIL.Image import Image
from torch import Tensor
@ -13,15 +13,16 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.model import ModelField
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
"""
@ -299,22 +300,25 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
def exists(self, key: str) -> bool:
def exists(self, identifier: Union[str, "ModelField"]) -> bool:
"""Checks if a model exists.
Args:
key: The key of the model.
identifier: The key or ModelField representing the model.
Returns:
True if the model exists, False if not.
"""
return self._services.model_manager.store.exists(key)
if isinstance(identifier, str):
return self._services.model_manager.store.exists(identifier)
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
return self._services.model_manager.store.exists(identifier.key)
def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Loads a model.
Args:
key: The key of the model.
identifier: The key or ModelField representing the model.
submodel_type: The submodel of the model to get.
Returns:
@ -324,9 +328,13 @@ class ModelsInterface(InvocationContextInterface):
# The model manager emits events as it loads the model. It needs the context data to build
# the event payloads.
return self._services.model_manager.load_model_by_key(
key=key, submodel_type=submodel_type, context_data=self._data
)
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
else:
_submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@ -343,35 +351,29 @@ class ModelsInterface(InvocationContextInterface):
Returns:
An object representing the loaded model.
"""
return self._services.model_manager.load_model_by_attr(
model_name=name,
base_model=base,
model_type=type,
submodel=submodel_type,
context_data=self._data,
)
def get_config(self, key: str) -> AnyModelConfig:
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
if len(configs) == 0:
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig:
"""Gets a model's config.
Args:
key: The key of the model.
identifier: The key or ModelField representing the model.
Returns:
The model's config.
"""
return self._services.model_manager.store.get_model(key=key)
if isinstance(identifier, str):
return self._services.model_manager.store.get_model(identifier)
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""Gets a model's metadata, if it has any.
Args:
key: The key of the model.
Returns:
The model's metadata, if it has any.
"""
return self._services.model_manager.store.get_metadata(key=key)
return self._services.model_manager.store.get_model(identifier.key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""Searches for models by path.

View File

@ -8,3 +8,8 @@ class UrlServiceBase(ABC):
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the URL for an image or thumbnail."""
pass
@abstractmethod
def get_model_image_url(self, model_key: str) -> str:
"""Gets the URL for a model image"""
pass

View File

@ -4,8 +4,9 @@ from .urls_base import UrlServiceBase
class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"):
def __init__(self, base_url: str = "api/v1", base_url_v2: str = "api/v2"):
self._base_url = base_url
self._base_url_v2 = base_url_v2
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
image_basename = os.path.basename(image_name)
@ -15,3 +16,6 @@ class LocalUrlService(UrlServiceBase):
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
return f"{self._base_url}/images/i/{image_basename}/full"
def get_model_image_url(self, model_key: str) -> str:
return f"{self._base_url_v2}/models/i/{model_key}/image"

View File

@ -22,7 +22,7 @@ def generate_ti_list(
for trigger in extract_ti_triggers_from_prompt(prompt):
name_or_key = trigger[1:-1]
try:
loaded_model = context.models.load(key=name_or_key)
loaded_model = context.models.load(name_or_key)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)
assert loaded_model.config.base == base

View File

@ -19,7 +19,6 @@ from invokeai.app.services.model_install import (
ModelInstallService,
ModelInstallServiceBase,
)
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
@ -39,7 +38,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
logger = InvokeAILogger.get_logger(config=app_config)
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
db = init_db(config=app_config, logger=logger, image_files=image_files)
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
return obj

View File

@ -161,6 +161,7 @@ class ModelConfigBase(BaseModel):
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
@ -310,7 +311,7 @@ class IPAdapterConfig(ModelConfigBase):
class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision."""
"""Model config for CLIPVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]

View File

@ -24,7 +24,7 @@ from .. import ModelLoader, ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
class LoraLoader(ModelLoader):
class LoRALoader(ModelLoader):
"""Class to load LoRA models."""
# We cheat a little bit to get access to the model base

View File

@ -23,7 +23,7 @@ from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VaeLoader(GenericDiffusersLoader):
class VAELoader(GenericDiffusersLoader):
"""Class to load VAE models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:

View File

@ -42,9 +42,10 @@ def install_and_load_model(
# If the requested model is already installed, return its LoadedModel
with contextlib.suppress(UnknownModelException):
# TODO: Replace with wrapper call
loaded_model: LoadedModel = model_manager.load_model_by_attr(
configs = model_manager.store.search_by_attr(
model_name=model_name, base_model=base_model, model_type=model_type
)
loaded_model: LoadedModel = model_manager.load.load_model(configs[0])
return loaded_model
# Install the requested model.
@ -53,7 +54,7 @@ def install_and_load_model(
assert job.complete
try:
loaded_model = model_manager.load_model_by_config(job.config_out)
loaded_model = model_manager.load.load_model(job.config_out)
return loaded_model
except UnknownModelException as e:
raise Exception(

View File

@ -20,7 +20,6 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
@ -413,7 +412,7 @@ def get_config_store() -> ModelRecordServiceSQL:
assert output_path is not None
image_files = DiskImageFileStorage(output_path / "images")
db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
return ModelRecordServiceSQL(db)
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:

View File

@ -746,6 +746,7 @@
"delete": "Delete",
"deleteConfig": "Delete Config",
"deleteModel": "Delete Model",
"deleteModelImage": "Delete Model Image",
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"description": "Description",
@ -786,6 +787,10 @@
"modelDeleteFailed": "Failed to delete model",
"modelEntryDeleted": "Model Entry Deleted",
"modelExists": "Model Exists",
"modelImageDeleted": "Model Image Deleted",
"modelImageDeleteFailed": "Model Image Delete Failed",
"modelImageUpdated": "Model Image Updated",
"modelImageUpdateFailed": "Model Image Update Failed",
"modelLocation": "Model Location",
"modelLocationValidationMsg": "Provide the path to a local folder where your Diffusers Model is stored",
"modelManager": "Model Manager",
@ -818,6 +823,7 @@
"oliveModels": "Olives",
"onnxModels": "Onnx",
"path": "Path",
"pathToConfig": "Path To Config",
"pathToCustomConfig": "Path To Custom Config",
"pickModelType": "Pick Model Type",
"predictionType": "Prediction Type",
@ -852,6 +858,7 @@
"triggerPhrases": "Trigger Phrases",
"typePhraseHere": "Type phrase here",
"upcastAttention": "Upcast Attention",
"uploadImage": "Upload Image",
"updateModel": "Update Model",
"useCustomConfig": "Use Custom Config",
"useDefaultSettings": "Use Default Settings",
@ -944,6 +951,7 @@
"doesNotExist": "does not exist",
"downloadWorkflow": "Download Workflow JSON",
"edge": "Edge",
"edit": "Edit",
"editMode": "Edit in Workflow Editor",
"enum": "Enum",
"enumDescription": "Enums are values that may be one of a number of options.",
@ -1019,6 +1027,7 @@
"nodeTemplate": "Node Template",
"nodeType": "Node Type",
"noFieldsLinearview": "No fields added to Linear View",
"noFieldsViewMode": "This workflow has no selected fields to display. View the full workflow to configure values.",
"noFieldType": "No field type",
"noImageFoundState": "No initial image found in state",
"noMatchingNodes": "No matching nodes",
@ -1806,6 +1815,7 @@
"cursorPosition": "Cursor Position",
"darkenOutsideSelection": "Darken Outside Selection",
"discardAll": "Discard All",
"discardCurrent": "Discard Current",
"downloadAsImage": "Download As Image",
"emptyFolder": "Empty Folder",
"emptyTempImageFolder": "Empty Temp Image Folder",
@ -1815,6 +1825,7 @@
"eraseBoundingBox": "Erase Bounding Box",
"eraser": "Eraser",
"fillBoundingBox": "Fill Bounding Box",
"invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
"layer": "Layer",
"limitStrokesToBox": "Limit Strokes to Box",
"mask": "Mask",

View File

@ -115,7 +115,8 @@
"safetensors": "Safetensors",
"ai": "ia",
"file": "File",
"toResolve": "Da risolvere"
"toResolve": "Da risolvere",
"add": "Aggiungi"
},
"gallery": {
"generations": "Generazioni",
@ -153,7 +154,12 @@
"starImage": "Immagine preferita",
"dropToUpload": "$t(gallery.drop) per aggiornare",
"problemDeletingImagesDesc": "Impossibile eliminare una o più immagini",
"problemDeletingImages": "Problema durante l'eliminazione delle immagini"
"problemDeletingImages": "Problema durante l'eliminazione delle immagini",
"bulkDownloadRequested": "Preparazione del download",
"bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. L'operazione potrebbe richiedere alcuni istanti.",
"bulkDownloadRequestFailed": "Problema durante la preparazione del download",
"bulkDownloadStarting": "Avvio scaricamento",
"bulkDownloadFailed": "Scaricamento fallito"
},
"hotkeys": {
"keyboardShortcuts": "Tasti di scelta rapida",
@ -505,12 +511,12 @@
"modelSyncFailed": "Sincronizzazione modello non riuscita",
"settings": "Impostazioni",
"syncModels": "Sincronizza Modelli",
"syncModelsDesc": "Se i tuoi modelli non sono sincronizzati con il back-end, puoi aggiornarli utilizzando questa opzione. Questo è generalmente utile nei casi in cui aggiorni manualmente il tuo file models.yaml o aggiungi modelli alla cartella principale di InvokeAI dopo l'avvio dell'applicazione.",
"syncModelsDesc": "Se i tuoi modelli non sono sincronizzati con il back-end, puoi aggiornarli utilizzando questa opzione. Questo è generalmente utile nei casi in cui aggiungi modelli alla cartella principale di InvokeAI dopo l'avvio dell'applicazione.",
"loraModels": "LoRA",
"oliveModels": "Olive",
"onnxModels": "ONNX",
"noModels": "Nessun modello trovato",
"predictionType": "Tipo di previsione (per modelli Stable Diffusion 2.x ed alcuni modelli Stable Diffusion 1.x)",
"predictionType": "Tipo di previsione",
"quickAdd": "Aggiunta rapida",
"simpleModelDesc": "Fornire un percorso a un modello diffusori locale, un modello checkpoint/safetensor locale, un ID repository HuggingFace o un URL del modello checkpoint/diffusori.",
"advanced": "Avanzate",
@ -521,7 +527,34 @@
"vaePrecision": "Precisione VAE",
"noModelSelected": "Nessun modello selezionato",
"conversionNotSupported": "Conversione non supportata",
"configFile": "File di configurazione"
"configFile": "File di configurazione",
"modelName": "Nome del modello",
"modelSettings": "Impostazioni del modello",
"advancedImportInfo": "La scheda opzioni avanzate consente la configurazione manuale delle impostazioni del modello principale. Utilizza questa scheda solo se sei sicuro di conoscere il tipo di modello e la configurazione corretti per il modello selezionato.",
"addAll": "Aggiungi tutto",
"addModels": "Aggiungi modelli",
"cancel": "Annulla",
"edit": "Modifica",
"imageEncoderModelId": "ID modello codificatore di immagini",
"importQueue": "Coda di importazione",
"modelMetadata": "Metadati del modello",
"path": "Percorso",
"prune": "Elimina",
"pruneTooltip": "Elimina dalla coda le importazioni completate",
"removeFromQueue": "Rimuovi dalla coda",
"repoVariant": "Variante del repository",
"scan": "Scansiona",
"scanFolder": "Scansione cartella",
"scanResults": "Risultati della scansione",
"source": "Sorgente",
"upcastAttention": "Eleva l'attenzione",
"ztsnrTraining": "Addestramento ZTSNR",
"typePhraseHere": "Digita la frase qui",
"defaultSettingsSaved": "Impostazioni predefinite salvate",
"defaultSettings": "Impostazioni predefinite",
"metadata": "Metadati",
"useDefaultSettings": "Usa le impostazioni predefinite",
"triggerPhrases": "Frasi trigger"
},
"parameters": {
"images": "Immagini",
@ -603,8 +636,8 @@
"clipSkip": "CLIP Skip",
"aspectRatio": "Proporzioni",
"maskAdjustmentsHeader": "Regolazioni della maschera",
"maskBlur": "Sfocatura",
"maskBlurMethod": "Metodo di sfocatura",
"maskBlur": "Sfocatura maschera",
"maskBlurMethod": "Metodo sfocatura maschera",
"seamLowThreshold": "Basso",
"seamHighThreshold": "Alto",
"coherencePassHeader": "Passaggio di coerenza",
@ -661,7 +694,8 @@
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (potrebbe essere troppo grande)",
"boxBlur": "Box",
"gaussianBlur": "Gaussian",
"remixImage": "Remixa l'immagine"
"remixImage": "Remixa l'immagine",
"coherenceEdgeSize": "Dimensione bordo"
},
"settings": {
"models": "Modelli",
@ -744,8 +778,8 @@
"canceled": "Elaborazione annullata",
"problemCopyingImageLink": "Impossibile copiare il collegamento dell'immagine",
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
"parameterSet": "Parametro impostato",
"parameterNotSet": "Parametro non impostato",
"parameterSet": "{{parameter}} impostato",
"parameterNotSet": "{{parameter}} non impostato",
"nodesLoadedFailed": "Impossibile caricare i nodi",
"nodesSaved": "Nodi salvati",
"nodesLoaded": "Nodi caricati",
@ -798,7 +832,10 @@
"problemRetrievingWorkflow": "Problema nel recupero del flusso di lavoro",
"resetInitialImage": "Reimposta l'immagine iniziale",
"uploadInitialImage": "Carica l'immagine iniziale",
"problemDownloadingImage": "Impossibile scaricare l'immagine"
"problemDownloadingImage": "Impossibile scaricare l'immagine",
"prunedQueue": "Coda ripulita",
"modelImportCanceled": "Importazione del modello annullata",
"modelImportRemoved": "Importazione del modello rimossa"
},
"tooltip": {
"feature": {
@ -876,7 +913,10 @@
"antialiasing": "Anti aliasing",
"showResultsOn": "Mostra i risultati (attivato)",
"showResultsOff": "Mostra i risultati (disattivato)",
"saveMask": "Salva $t(unifiedCanvas.mask)"
"saveMask": "Salva $t(unifiedCanvas.mask)",
"coherenceModeGaussianBlur": "Sfocatura Gaussiana",
"coherenceModeBoxBlur": "Sfocatura Box",
"coherenceModeStaged": "Maschera espansa"
},
"accessibility": {
"modelSelect": "Seleziona modello",
@ -1345,7 +1385,8 @@
"allLoRAsAdded": "Tutti i LoRA aggiunti",
"defaultVAE": "VAE predefinito",
"incompatibleBaseModel": "Modello base incompatibile",
"loraAlreadyAdded": "LoRA già aggiunto"
"loraAlreadyAdded": "LoRA già aggiunto",
"concepts": "Concetti"
},
"invocationCache": {
"disable": "Disabilita",
@ -1698,6 +1739,25 @@
"paragraphs": [
"Valuta le generazioni in modo che siano più simili alle immagini con un punteggio estetico elevato, in base ai dati di addestramento."
]
},
"compositingCoherenceMinDenoise": {
"heading": "Livello minimo di riduzione del rumore",
"paragraphs": [
"Intensità minima di riduzione rumore per la modalità di Coerenza",
"L'intensità minima di riduzione del rumore per la regione di coerenza durante l'inpainting o l'outpainting"
]
},
"compositingMaskBlur": {
"paragraphs": [
"Il raggio di sfocatura della maschera."
],
"heading": "Sfocatura maschera"
},
"compositingCoherenceEdgeSize": {
"heading": "Dimensione del bordo",
"paragraphs": [
"La dimensione del bordo del passaggio di coerenza."
]
}
},
"sdxl": {
@ -1746,7 +1806,12 @@
"scheduler": "Campionatore",
"recallParameters": "Richiama i parametri",
"noRecallParameters": "Nessun parametro da richiamare trovato",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"allPrompts": "Tutti i prompt",
"imageDimensions": "Dimensioni dell'immagine",
"parameterSet": "Parametro {{parameter}} impostato",
"parsingFailed": "Analisi non riuscita",
"recallParameter": "Richiama {{label}}"
},
"hrf": {
"enableHrf": "Abilita Correzione Alta Risoluzione",
@ -1818,5 +1883,11 @@
"image": {
"title": "Immagine"
}
},
"prompt": {
"compatibleEmbeddings": "Incorporamenti compatibili",
"addPromptTrigger": "Aggiungi parola chiave nel prompt",
"noPromptTriggers": "Nessuna parola chiave disponibile",
"noMatchingTriggers": "Nessuna parola chiave corrispondente"
}
}

View File

@ -52,7 +52,7 @@
"accept": "Принять",
"postprocessing": "Постобработка",
"txt2img": "Текст в изображение (txt2img)",
"linear": "Линейная обработка",
"linear": "Линейный вид",
"dontAskMeAgain": "Больше не спрашивать",
"areYouSure": "Вы уверены?",
"random": "Случайное",
@ -117,7 +117,8 @@
"toResolve": "Чтоб решить",
"copy": "Копировать",
"localSystem": "Локальная система",
"aboutDesc": "Используя Invoke для работы? Проверьте это:"
"aboutDesc": "Используя Invoke для работы? Проверьте это:",
"add": "Добавить"
},
"gallery": {
"generations": "Генерации",
@ -155,7 +156,12 @@
"noImageSelected": "Изображение не выбрано",
"setCurrentImage": "Установить как текущее изображение",
"starImage": "Добавить в избранное",
"dropToUpload": "$t(gallery.drop) чтоб загрузить"
"dropToUpload": "$t(gallery.drop) чтоб загрузить",
"bulkDownloadFailed": "Загрузка не удалась",
"bulkDownloadStarting": "Начало загрузки",
"bulkDownloadRequested": "Подготовка к скачиванию",
"bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.",
"bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания"
},
"hotkeys": {
"keyboardShortcuts": "Горячие клавиши",
@ -504,7 +510,7 @@
"settings": "Настройки",
"selectModel": "Выберите модель",
"syncModels": "Синхронизация моделей",
"syncModelsDesc": "Если ваши модели не синхронизированы с серверной частью, вы можете обновить их, используя эту опцию. Обычно это удобно в тех случаях, когда вы вручную обновляете свой файл \"models.yaml\" или добавляете модели в корневую папку InvokeAI после загрузки приложения.",
"syncModelsDesc": "Если ваши модели не синхронизированы с серверной частью, вы можете обновить их с помощью этой опции. Обычно это удобно в тех случаях, когда вы добавляете модели в корневую папку InvokeAI или каталог автоимпорта после загрузки приложения.",
"modelUpdateFailed": "Не удалось обновить модель",
"modelConversionFailed": "Не удалось сконвертировать модель",
"modelsMergeFailed": "Не удалось выполнить слияние моделей",
@ -513,7 +519,7 @@
"oliveModels": "Модели Olives",
"conversionNotSupported": "Преобразование не поддерживается",
"noModels": "Нет моделей",
"predictionType": "Тип прогноза (для моделей Stable Diffusion 2.x и периодических моделей Stable Diffusion 1.x)",
"predictionType": "Тип прогноза",
"quickAdd": "Быстрое добавление",
"simpleModelDesc": "Укажите путь к локальной модели Diffusers , локальной модели checkpoint / safetensors, идентификатор репозитория HuggingFace или URL-адрес модели контрольной checkpoint / diffusers.",
"advanced": "Продвинутый",
@ -524,7 +530,33 @@
"customConfigFileLocation": "Расположение пользовательского файла конфигурации",
"vaePrecision": "Точность VAE",
"noModelSelected": "Модель не выбрана",
"configFile": "Файл конфигурации"
"configFile": "Файл конфигурации",
"addAll": "Добавить всё",
"addModels": "Добавить модели",
"cancel": "Отмена",
"defaultSettings": "Стандартные настройки",
"importQueue": "Импортировать очередь",
"metadata": "Метаданные",
"imageEncoderModelId": "ID модели-энкодера изображений",
"typePhraseHere": "Введите фразы здесь",
"advancedImportInfo": "Вкладка «Дополнительно» позволяет вручную настроить основные параметры модели. Используйте эту вкладку только в том случае, если вы уверены, что знаете правильный тип модели и конфигурацию выбранной модели.",
"defaultSettingsSaved": "Стандартные настройки сохранены",
"edit": "Редактировать",
"path": "Путь",
"prune": "Удалить",
"pruneTooltip": "Удалить готовые импорты из очереди",
"removeFromQueue": "Удалить из очереди",
"repoVariant": "Вариант репозитория",
"scan": "Сканировать",
"scanFolder": "Сканировать папку",
"scanResults": "Результаты сканирования",
"source": "Источник",
"triggerPhrases": "Триггерные фразы",
"useDefaultSettings": "Использовать стандартные настройки",
"modelMetadata": "Метаданные модели",
"modelName": "Название модели",
"modelSettings": "Настройки модели",
"upcastAttention": "Внимание"
},
"parameters": {
"images": "Изображения",
@ -591,7 +623,7 @@
"hSymmetryStep": "Шаг гор. симметрии",
"hidePreview": "Скрыть предпросмотр",
"imageToImage": "Изображение в изображение",
"denoisingStrength": "Сила шумоподавления",
"denoisingStrength": "Сила зашумления",
"copyImage": "Скопировать изображение",
"showPreview": "Показать предпросмотр",
"noiseSettings": "Шум",
@ -606,8 +638,8 @@
"clipSkip": "CLIP Пропуск",
"aspectRatio": "Соотношение",
"maskAdjustmentsHeader": "Настройка маски",
"maskBlur": "Размытие",
"maskBlurMethod": "Метод размытия",
"maskBlur": "Размытие маски",
"maskBlurMethod": "Метод размытия маски",
"seamLowThreshold": "Низкий",
"seamHighThreshold": "Высокий",
"coherencePassHeader": "Порог Coherence",
@ -666,7 +698,9 @@
"lockAspectRatio": "Заблокировать соотношение",
"boxBlur": "Размытие прямоугольника",
"gaussianBlur": "Размытие по Гауссу",
"remixImage": "Ремикс изображения"
"remixImage": "Ремикс изображения",
"coherenceMinDenoise": "Мин. шумоподавление",
"coherenceEdgeSize": "Размер края"
},
"settings": {
"models": "Модели",
@ -749,8 +783,8 @@
"canceled": "Обработка отменена",
"problemCopyingImageLink": "Не удалось скопировать ссылку на изображение",
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
"parameterNotSet": "Параметр не задан",
"parameterSet": "Параметр задан",
"parameterNotSet": "Параметр {{parameter}} не задан",
"parameterSet": "Параметр {{parameter}} задан",
"nodesLoaded": "Узлы загружены",
"problemCopyingImage": "Не удается скопировать изображение",
"nodesLoadedFailed": "Не удалось загрузить Узлы",
@ -803,7 +837,10 @@
"problemImportingMask": "Проблема с импортом маски",
"problemDownloadingImage": "Не удается скачать изображение",
"uploadInitialImage": "Загрузить начальное изображение",
"resetInitialImage": "Сбросить начальное изображение"
"resetInitialImage": "Сбросить начальное изображение",
"prunedQueue": "Урезанная очередь",
"modelImportCanceled": "Импорт модели отменен",
"modelImportRemoved": "Импорт модели удален"
},
"tooltip": {
"feature": {
@ -1145,7 +1182,11 @@
"reorderLinearView": "Изменить порядок линейного просмотра",
"viewMode": "Использовать в линейном представлении",
"editMode": "Открыть в редакторе узлов",
"resetToDefaultValue": "Сбросить к стандартному значкнию"
"resetToDefaultValue": "Сбросить к стандартному значкнию",
"latentsField": "Латенты",
"latentsCollectionDescription": "Латенты могут передаваться между узлами.",
"latentsPolymorphicDescription": "Латенты могут передаваться между узлами.",
"latentsFieldDescription": "Латенты могут передаваться между узлами."
},
"controlnet": {
"amult": "a_mult",
@ -1294,7 +1335,8 @@
},
"paramScheduler": {
"paragraphs": [
"Планировщик определяет, как итеративно добавлять шум к изображению или как обновлять образец на основе выходных данных модели."
"Планировщик, используемый в процессе генерации.",
"Каждый планировщик определяет, как итеративно добавлять шум к изображению или как обновлять образец на основе выходных данных модели."
],
"heading": "Планировщик"
},
@ -1347,7 +1389,7 @@
"compositingCoherenceMode": {
"heading": "Режим",
"paragraphs": [
"Режим прохождения когерентности."
"Метод, используемый для создания связного изображения с вновь созданной замаскированной областью."
]
},
"paramSeed": {
@ -1365,7 +1407,7 @@
},
"controlNetBeginEnd": {
"paragraphs": [
"На каких этапах процесса шумоподавления будет применена ControlNet.",
"Часть процесса шумоподавления, к которой будет применен адаптер контроля.",
"ControlNet, применяемые в начале процесса, направляют композицию, а ControlNet, применяемые в конце, направляют детали."
],
"heading": "Процент начала/конца шага"
@ -1381,8 +1423,8 @@
},
"clipSkip": {
"paragraphs": [
"Выберите, сколько слоев модели CLIP нужно пропустить.",
"Некоторые модели работают лучше с определенными настройками пропуска CLIP."
"Сколько слоев модели CLIP пропустить.",
"Некоторые модели лучше подходят для использования с CLIP Skip."
],
"heading": "CLIP пропуск"
},
@ -1479,6 +1521,25 @@
"paragraphs": [
"Более высокий вес LoRA приведет к большему влиянию на конечное изображение."
]
},
"compositingMaskBlur": {
"heading": "Размытие маски",
"paragraphs": [
"Радиус размытия маски."
]
},
"compositingCoherenceMinDenoise": {
"heading": "Минимальное шумоподавление",
"paragraphs": [
"Минимальный уровень шумоподавления для режима Coherence",
"Минимальный уровень шумоподавления для области когерентности при перерисовывании или дорисовке"
]
},
"compositingCoherenceEdgeSize": {
"heading": "Размер края",
"paragraphs": [
"Размер края прохода когерентности."
]
}
},
"metadata": {
@ -1509,7 +1570,12 @@
"steps": "Шаги",
"scheduler": "Планировщик",
"noRecallParameters": "Параметры для вызова не найдены",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"parameterSet": "Параметр {{parameter}} установлен",
"parsingFailed": "Не удалось выполнить синтаксический анализ",
"recallParameter": "Отозвать {{label}}",
"allPrompts": "Все запросы",
"imageDimensions": "Размеры изображения"
},
"queue": {
"status": "Статус",
@ -1588,10 +1654,11 @@
"denoisingStrength": "Шумоподавление",
"refinermodel": "Модель перерисовщик",
"posAestheticScore": "Положительная эстетическая оценка",
"concatPromptStyle": "Объединение запроса и стиля",
"concatPromptStyle": "Связывание запроса и стиля",
"loading": "Загрузка...",
"steps": "Шаги",
"posStylePrompt": "Запрос стиля"
"posStylePrompt": "Запрос стиля",
"freePromptStyle": "Ручной запрос стиля"
},
"invocationCache": {
"useCache": "Использовать кэш",
@ -1678,7 +1745,8 @@
"allLoRAsAdded": "Все LoRA добавлены",
"defaultVAE": "Стандартное VAE",
"incompatibleBaseModel": "Несовместимая базовая модель",
"loraAlreadyAdded": "LoRA уже добавлена"
"loraAlreadyAdded": "LoRA уже добавлена",
"concepts": "Концепты"
},
"app": {
"storeNotInitialized": "Магазин не инициализирован"
@ -1696,7 +1764,7 @@
},
"generation": {
"title": "Генерация",
"conceptsTab": "Концепты",
"conceptsTab": "LoRA",
"modelTab": "Модель"
},
"advanced": {

View File

@ -5,18 +5,55 @@ import openapiTS from 'openapi-typescript';
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
const OUTPUT_FILE = 'src/services/api/schema.ts';
async function main() {
process.stdout.write(`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...`);
const types = await openapiTS(OPENAPI_URL, {
async function generateTypes(schema) {
process.stdout.write(`Generating types ${OUTPUT_FILE}...`);
const types = await openapiTS(schema, {
exportType: true,
transform: (schemaObject) => {
if ('format' in schemaObject && schemaObject.format === 'binary') {
return schemaObject.nullable ? 'Blob | null' : 'Blob';
}
if (schemaObject.title === 'MetadataField') {
// This is `Record<string, never>` by default, but it actually accepts a dict of any valid JSON value.
return 'Record<string, unknown>';
}
},
});
fs.writeFileSync(OUTPUT_FILE, types);
process.stdout.write(`\nOK!\r\n`);
}
async function main() {
const encoding = 'utf-8';
if (process.stdin.isTTY) {
// Handle generating types with an arg (e.g. URL or path to file)
if (process.argv.length > 3) {
console.error('Usage: typegen.js <openapi.json>');
process.exit(1);
}
if (process.argv[2]) {
const schema = new Buffer.from(process.argv[2], encoding);
generateTypes(schema);
} else {
generateTypes(OPENAPI_URL);
}
} else {
// Handle generating types from stdin
let schema = '';
process.stdin.setEncoding(encoding);
process.stdin.on('readable', function () {
const chunk = process.stdin.read();
if (chunk !== null) {
schema += chunk;
}
});
process.stdin.on('end', function () {
generateTypes(JSON.parse(schema));
});
}
}
main();

View File

@ -55,6 +55,7 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
import { addControlAdapterAutoProcessorUpdateListener } from './listeners/controlAdapterAutoProcessorUpdateListener';
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
export const listenerMiddleware = createListenerMiddleware();
@ -125,6 +126,7 @@ addBulkDownloadListeners(startAppListening);
// ControlNet
addControlNetImageProcessedListener(startAppListening);
addControlNetAutoProcessListener(startAppListening);
addControlAdapterAutoProcessorUpdateListener(startAppListening);
// Boards
addImageAddedToBoardFulfilledListener(startAppListening);

View File

@ -38,7 +38,7 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
type: 'image/png',
}),
image_category: 'control',
is_intermediate: false,
is_intermediate: true,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: false,
postUploadAction: {

View File

@ -48,7 +48,7 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: false,
is_intermediate: true,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: false,
postUploadAction: {

View File

@ -0,0 +1,77 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { CONTROLNET_MODEL_DEFAULT_PROCESSORS } from 'features/controlAdapters/store/constants';
import {
controlAdapterAutoConfigToggled,
controlAdapterModelChanged,
controlAdapterProcessortTypeChanged,
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { modelsApi } from 'services/api/endpoints/models';
export const addControlAdapterAutoProcessorUpdateListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: isAnyOf(controlAdapterModelChanged, controlAdapterAutoConfigToggled),
effect: async (action, { getState, dispatch }) => {
let id;
let model;
const state = getState();
const { controlAdapters: controlAdaptersState } = state;
if (controlAdapterModelChanged.match(action)) {
model = action.payload.model;
id = action.payload.id;
const cn = selectControlAdapterById(controlAdaptersState, id);
if (!cn) {
return;
}
if (!isControlNetOrT2IAdapter(cn)) {
return;
}
}
if (controlAdapterAutoConfigToggled.match(action)) {
id = action.payload.id;
const cn = selectControlAdapterById(controlAdaptersState, id);
if (!cn) {
return;
}
if (!isControlNetOrT2IAdapter(cn)) {
return;
}
// if they turned off autoconfig, return
if (!cn.shouldAutoConfig) {
return;
}
model = cn.model;
}
if (!model || !id) {
return;
}
let processorType: ControlAdapterProcessorType | undefined = undefined;
const { data: modelConfig } = modelsApi.endpoints.getModelConfig.select(model.key)(state);
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (modelConfig?.name.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
dispatch(
controlAdapterProcessortTypeChanged({ id, processorType: processorType || 'none', shouldAutoConfig: true })
);
},
});
};

View File

@ -101,7 +101,7 @@ export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartLis
).unwrap();
}
const graph = buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
log.debug({ graph: parseify(graph) }, `Canvas graph built`);

View File

@ -20,15 +20,15 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
if (model && model.base === 'sdxl') {
if (action.payload.tabName === 'txt2img') {
graph = buildLinearSDXLTextToImageGraph(state);
graph = await buildLinearSDXLTextToImageGraph(state);
} else {
graph = buildLinearSDXLImageToImageGraph(state);
graph = await buildLinearSDXLImageToImageGraph(state);
}
} else {
if (action.payload.tabName === 'txt2img') {
graph = buildLinearTextToImageGraph(state);
graph = await buildLinearTextToImageGraph(state);
} else {
graph = buildLinearImageToImageGraph(state);
graph = await buildLinearImageToImageGraph(state);
}
}

View File

@ -20,7 +20,7 @@ const sx: ChakraProps['sx'] = {
'.react-colorful__hue-pointer': colorPickerPointerStyles,
'.react-colorful__saturation-pointer': colorPickerPointerStyles,
'.react-colorful__alpha-pointer': colorPickerPointerStyles,
gap: 2,
gap: 5,
flexDir: 'column',
};
@ -39,8 +39,8 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
<Flex sx={sx}>
<RgbaColorPicker color={color} onChange={onChange} style={colorPickerStyles} {...rest} />
{withNumberInput && (
<Flex>
<FormControl>
<Flex gap={5}>
<FormControl gap={0}>
<FormLabel>{t('common.red')}</FormLabel>
<CompositeNumberInput
value={color.r}
@ -52,7 +52,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
defaultValue={90}
/>
</FormControl>
<FormControl>
<FormControl gap={0}>
<FormLabel>{t('common.green')}</FormLabel>
<CompositeNumberInput
value={color.g}
@ -64,7 +64,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
defaultValue={90}
/>
</FormControl>
<FormControl>
<FormControl gap={0}>
<FormLabel>{t('common.blue')}</FormLabel>
<CompositeNumberInput
value={color.b}
@ -76,7 +76,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
defaultValue={255}
/>
</FormControl>
<FormControl>
<FormControl gap={0}>
<FormLabel>{t('common.alpha')}</FormLabel>
<CompositeNumberInput
value={color.a}

View File

@ -29,7 +29,7 @@ import { Layer, Stage } from 'react-konva';
import IAICanvasBoundingBoxOverlay from './IAICanvasBoundingBoxOverlay';
import IAICanvasGrid from './IAICanvasGrid';
import IAICanvasIntermediateImage from './IAICanvasIntermediateImage';
import IAICanvasMaskCompositer from './IAICanvasMaskCompositer';
import IAICanvasMaskCompositor from './IAICanvasMaskCompositor';
import IAICanvasMaskLines from './IAICanvasMaskLines';
import IAICanvasObjectRenderer from './IAICanvasObjectRenderer';
import IAICanvasStagingArea from './IAICanvasStagingArea';
@ -176,7 +176,7 @@ const IAICanvas = () => {
</Layer>
<Layer id="mask" visible={isMaskEnabled && !isStaging} listening={false}>
<IAICanvasMaskLines visible={true} listening={false} />
<IAICanvasMaskCompositer listening={false} />
<IAICanvasMaskCompositor listening={false} />
</Layer>
<Layer listening={false}>
<IAICanvasBoundingBoxOverlay />

View File

@ -16,9 +16,9 @@ const canvasMaskCompositerSelector = createMemoizedSelector(selectCanvasSlice, (
};
});
type IAICanvasMaskCompositerProps = RectConfig;
type IAICanvasMaskCompositorProps = RectConfig;
const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
const IAICanvasMaskCompositor = (props: IAICanvasMaskCompositorProps) => {
const { ...rest } = props;
const { stageCoordinates, stageDimensions } = useAppSelector(canvasMaskCompositerSelector);
@ -89,4 +89,4 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
);
};
export default memo(IAICanvasMaskCompositer);
export default memo(IAICanvasMaskCompositor);

View File

@ -5,6 +5,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { stagingAreaImageSaved } from 'features/canvas/store/actions';
import {
commitStagingAreaImage,
discardStagedImage,
discardStagedImages,
nextStagingAreaImage,
prevStagingAreaImage,
@ -22,6 +23,7 @@ import {
PiEyeBold,
PiEyeSlashBold,
PiFloppyDiskBold,
PiTrashSimpleBold,
PiXBold,
} from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
@ -44,6 +46,40 @@ const selector = createMemoizedSelector(selectCanvasSlice, (canvas) => {
};
});
const ClearStagingIntermediatesIconButton = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleDiscardStagingArea = useCallback(() => {
dispatch(discardStagedImages());
}, [dispatch]);
const handleDiscardStagingImage = useCallback(() => {
dispatch(discardStagedImage());
}, [dispatch]);
return (
<>
<IconButton
tooltip={`${t('unifiedCanvas.discardCurrent')}`}
aria-label={t('unifiedCanvas.discardCurrent')}
icon={<PiXBold />}
onClick={handleDiscardStagingImage}
colorScheme="invokeBlue"
fontSize={16}
/>
<IconButton
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
aria-label={t('unifiedCanvas.discardAll')}
icon={<PiTrashSimpleBold />}
onClick={handleDiscardStagingArea}
colorScheme="error"
fontSize={16}
/>
</>
);
};
const IAICanvasStagingAreaToolbar = () => {
const dispatch = useAppDispatch();
const { currentStagingAreaImage, shouldShowStagingImage, currentIndex, total } = useAppSelector(selector);
@ -185,14 +221,7 @@ const IAICanvasStagingAreaToolbar = () => {
onClick={handleSaveToGallery}
colorScheme="invokeBlue"
/>
<IconButton
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
aria-label={t('unifiedCanvas.discardAll')}
icon={<PiXBold />}
onClick={handleDiscardStagingArea}
colorScheme="error"
fontSize={20}
/>
<ClearStagingIntermediatesIconButton />
</ButtonGroup>
</Flex>
);

View File

@ -18,6 +18,7 @@ import {
setShouldAutoSave,
setShouldCropToBoundingBoxOnSave,
setShouldDarkenOutsideBoundingBox,
setShouldInvertBrushSizeScrollDirection,
setShouldRestrictStrokesToBox,
setShouldShowCanvasDebugInfo,
setShouldShowGrid,
@ -40,6 +41,7 @@ const IAICanvasSettingsButtonPopover = () => {
const shouldAutoSave = useAppSelector((s) => s.canvas.shouldAutoSave);
const shouldCropToBoundingBoxOnSave = useAppSelector((s) => s.canvas.shouldCropToBoundingBoxOnSave);
const shouldDarkenOutsideBoundingBox = useAppSelector((s) => s.canvas.shouldDarkenOutsideBoundingBox);
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
const shouldShowCanvasDebugInfo = useAppSelector((s) => s.canvas.shouldShowCanvasDebugInfo);
const shouldShowGrid = useAppSelector((s) => s.canvas.shouldShowGrid);
const shouldShowIntermediates = useAppSelector((s) => s.canvas.shouldShowIntermediates);
@ -76,6 +78,10 @@ const IAICanvasSettingsButtonPopover = () => {
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked)),
[dispatch]
);
const handleChangeShouldInvertBrushSizeScrollDirection = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldInvertBrushSizeScrollDirection(e.target.checked)),
[dispatch]
);
const handleChangeShouldAutoSave = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldAutoSave(e.target.checked)),
[dispatch]
@ -144,6 +150,13 @@ const IAICanvasSettingsButtonPopover = () => {
<FormLabel>{t('unifiedCanvas.limitStrokesToBox')}</FormLabel>
<Checkbox isChecked={shouldRestrictStrokesToBox} onChange={handleChangeShouldRestrictStrokesToBox} />
</FormControl>
<FormControl>
<FormLabel>{t('unifiedCanvas.invertBrushSizeScrollDirection')}</FormLabel>
<Checkbox
isChecked={shouldInvertBrushSizeScrollDirection}
onChange={handleChangeShouldInvertBrushSizeScrollDirection}
/>
</FormControl>
<FormControl>
<FormLabel>{t('unifiedCanvas.showCanvasDebugInfo')}</FormLabel>
<Checkbox isChecked={shouldShowCanvasDebugInfo} onChange={handleChangeShouldShowCanvasDebugInfo} />

View File

@ -15,6 +15,7 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
const stageScale = useAppSelector((s) => s.canvas.stageScale);
const isMoveStageKeyHeld = useStore($isMoveStageKeyHeld);
const brushSize = useAppSelector((s) => s.canvas.brushSize);
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
return useCallback(
(e: KonvaEventObject<WheelEvent>) => {
@ -28,10 +29,16 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
// check whether the ctrl (or meta) key is pressed,
// so that brush size can be controlled using ctrl + scroll up/down
// Invert the delta if the property is set to true
let delta = e.evt.deltaY;
if (shouldInvertBrushSizeScrollDirection) {
delta = -delta;
}
if ($ctrl.get() || $meta.get()) {
// This equation was derived by fitting a curve to the desired brush sizes and deltas
// see https://github.com/invoke-ai/InvokeAI/pull/5542#issuecomment-1915847565
const targetDelta = Math.sign(e.evt.deltaY) * 0.7363 * Math.pow(1.0394, brushSize);
const targetDelta = Math.sign(delta) * 0.7363 * Math.pow(1.0394, brushSize);
// This needs to be clamped to prevent the delta from getting too large
const finalDelta = clamp(targetDelta, -20, 20);
// The new brush size is also clamped to prevent it from getting too large or small
@ -67,7 +74,7 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
dispatch(setStageCoordinates(newCoordinates));
}
},
[stageRef, isMoveStageKeyHeld, stageScale, dispatch, brushSize]
[stageRef, isMoveStageKeyHeld, brushSize, dispatch, stageScale, shouldInvertBrushSizeScrollDirection]
);
};
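A minimal sketch of the brush-size math this hook now uses, assuming the same curve constants and clamp bounds shown in the hunk above; the final brush-size limits (1 to 500) are illustrative, not taken from this diff.

// Sketch only: next brush size for a wheel event, honoring the invert-scroll setting.
const clampValue = (v: number, min: number, max: number) => Math.min(Math.max(v, min), max);

const nextBrushSize = (brushSize: number, deltaY: number, shouldInvert: boolean): number => {
  const delta = shouldInvert ? -deltaY : deltaY;
  // Curve fit from the hunk above: maps a scroll delta to a brush-size step.
  const targetDelta = Math.sign(delta) * 0.7363 * Math.pow(1.0394, brushSize);
  const finalDelta = clampValue(targetDelta, -20, 20);
  // Assumed brush-size bounds for illustration; the canvas slice enforces the real limits.
  return clampValue(brushSize + finalDelta, 1, 500);
};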

View File

@ -65,6 +65,7 @@ const initialCanvasState: CanvasState = {
shouldAutoSave: false,
shouldCropToBoundingBoxOnSave: false,
shouldDarkenOutsideBoundingBox: false,
shouldInvertBrushSizeScrollDirection: false,
shouldLockBoundingBox: false,
shouldPreserveMaskedArea: false,
shouldRestrictStrokesToBox: true,
@ -220,6 +221,9 @@ export const canvasSlice = createSlice({
setShouldDarkenOutsideBoundingBox: (state, action: PayloadAction<boolean>) => {
state.shouldDarkenOutsideBoundingBox = action.payload;
},
setShouldInvertBrushSizeScrollDirection: (state, action: PayloadAction<boolean>) => {
state.shouldInvertBrushSizeScrollDirection = action.payload;
},
clearCanvasHistory: (state) => {
state.pastLayerStates = [];
state.futureLayerStates = [];
@ -288,6 +292,31 @@ export const canvasSlice = createSlice({
state.shouldShowStagingImage = true;
state.batchIds = [];
},
discardStagedImage: (state) => {
const { images, selectedImageIndex } = state.layerState.stagingArea;
state.pastLayerStates.push(cloneDeep(state.layerState));
if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
if (!images.length) {
return;
}
images.splice(selectedImageIndex, 1);
if (selectedImageIndex >= images.length) {
state.layerState.stagingArea.selectedImageIndex = images.length - 1;
}
if (!images.length) {
state.shouldShowStagingImage = false;
state.shouldShowStagingOutline = false;
}
state.futureLayerStates = [];
},
addFillRect: (state) => {
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } = state;
@ -655,6 +684,7 @@ export const {
commitColorPickerColor,
commitStagingAreaImage,
discardStagedImages,
discardStagedImage,
nextStagingAreaImage,
prevStagingAreaImage,
redo,
@ -674,6 +704,7 @@ export const {
setShouldAutoSave,
setShouldCropToBoundingBoxOnSave,
setShouldDarkenOutsideBoundingBox,
setShouldInvertBrushSizeScrollDirection,
setShouldPreserveMaskedArea,
setShouldShowBoundingBox,
setShouldShowCanvasDebugInfo,

View File

@ -120,6 +120,7 @@ export interface CanvasState {
shouldAutoSave: boolean;
shouldCropToBoundingBoxOnSave: boolean;
shouldDarkenOutsideBoundingBox: boolean;
shouldInvertBrushSizeScrollDirection: boolean;
shouldLockBoundingBox: boolean;
shouldPreserveMaskedArea: boolean;
shouldRestrictStrokesToBox: boolean;

View File

@ -1,6 +1,12 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CONTROLNET_MODEL_DEFAULT_PROCESSORS, CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import type {
ControlAdapterProcessorType,
ControlAdapterType,
RequiredControlAdapterProcessorNode,
} from 'features/controlAdapters/store/types';
import { cloneDeep } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useControlAdapterModels } from './useControlAdapterModels';
@ -22,6 +28,29 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
return models[0];
}, [baseModel, models]);
const processor = useMemo(() => {
let processorType;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (firstModel?.name.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (!processorType) {
processorType = 'none';
}
const processorNode =
processorType === 'none'
? (cloneDeep(CONTROLNET_PROCESSORS.none.default) as RequiredControlAdapterProcessorNode)
: (cloneDeep(
CONTROLNET_PROCESSORS[processorType as ControlAdapterProcessorType].default
) as RequiredControlAdapterProcessorNode);
return { processorType, processorNode };
}, [firstModel]);
const isDisabled = useMemo(() => !firstModel, [firstModel]);
const addControlAdapter = useCallback(() => {
@ -31,10 +60,10 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
dispatch(
controlAdapterAdded({
type,
overrides: { model: firstModel },
overrides: { model: firstModel, ...processor },
})
);
}, [dispatch, firstModel, isDisabled, type]);
}, [dispatch, firstModel, isDisabled, type, processor]);
return [addControlAdapter, isDisabled] as const;
};
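For context, a hedged usage sketch of the hook exported above; the import paths and the 'controlnet' literal are assumptions for illustration, not taken from this diff.

// Sketch only: a button that adds a control adapter with its default model and processor.
import { Button } from '@invoke-ai/ui-library'; // assumed import
import { useAddControlAdapter } from './useAddControlAdapter'; // assumed path

export const AddControlNetButton = () => {
  const [addControlAdapter, isDisabled] = useAddControlAdapter('controlnet');
  return (
    <Button onClick={addControlAdapter} isDisabled={isDisabled}>
      Add ControlNet
    </Button>
  );
};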

View File

@ -13,10 +13,7 @@ import { socketInvocationError } from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
import { controlAdapterImageProcessed } from './actions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS as CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from './constants';
import { CONTROLNET_PROCESSORS } from './constants';
import type {
ControlAdapterConfig,
ControlAdapterProcessorType,
@ -215,25 +212,6 @@ export const controlAdaptersSlice = createSlice({
update.changes.processedControlImage = null;
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
if (model.key.includes(modelSubstring)) {
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
update.changes.processorType = processorType;
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlAdapterProcessorNode;
} else {
update.changes.processorType = 'none';
update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
}
caAdapter.updateOne(state, update);
},
controlAdapterWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
@ -298,17 +276,19 @@ export const controlAdaptersSlice = createSlice({
action: PayloadAction<{
id: string;
processorType: ControlAdapterProcessorType;
shouldAutoConfig?: boolean;
}>
) => {
const { id, processorType } = action.payload;
const { id, processorType, shouldAutoConfig = false } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn)) {
return;
}
const processorNode = cloneDeep(
CONTROLNET_PROCESSORS[processorType].default
) as RequiredControlAdapterProcessorNode;
const processorNode =
processorType === 'none'
? (cloneDeep(CONTROLNET_PROCESSORS.none.default) as RequiredControlAdapterProcessorNode)
: (cloneDeep(CONTROLNET_PROCESSORS[processorType].default) as RequiredControlAdapterProcessorNode);
caAdapter.updateOne(state, {
id,
@ -316,7 +296,7 @@ export const controlAdaptersSlice = createSlice({
processorType,
processedControlImage: null,
processorNode,
shouldAutoConfig: false,
shouldAutoConfig,
},
});
},
@ -337,28 +317,6 @@ export const controlAdaptersSlice = createSlice({
changes: { shouldAutoConfig: !cn.shouldAutoConfig },
};
if (update.changes.shouldAutoConfig) {
// manage the processor for the user
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
if (cn.model?.key.includes(modelSubstring)) {
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
update.changes.processorType = processorType;
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlAdapterProcessorNode;
} else {
update.changes.processorType = 'none';
update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
}
}
caAdapter.updateOne(state, update);
},
controlAdaptersReset: () => {

View File

@ -6,7 +6,7 @@ const AutoAddIcon = () => {
const { t } = useTranslation();
return (
<Flex position="absolute" insetInlineEnd={0} top={0} p={1}>
<Badge variant="solid" bg="invokeBlue.500">
<Badge variant="solid" bg="invokeBlue.400">
{t('common.auto')}
</Badge>
</Flex>

View File

@ -173,8 +173,8 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
w="full"
maxW="full"
borderBottomRadius="base"
bg={isSelected ? 'invokeBlue.500' : 'base.600'}
color={isSelected ? 'base.50' : 'base.100'}
bg={isSelected ? 'invokeBlue.400' : 'base.600'}
color={isSelected ? 'base.800' : 'base.100'}
lineHeight="short"
fontSize="xs"
>
@ -193,6 +193,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
overflow="hidden"
textOverflow="ellipsis"
noOfLines={1}
color="inherit"
/>
<EditableInput sx={editableInputStyles} />
</Editable>

View File

@ -109,8 +109,8 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
w="full"
maxW="full"
borderBottomRadius="base"
bg={isSelected ? 'invokeBlue.500' : 'base.600'}
color={isSelected ? 'base.50' : 'base.100'}
bg={isSelected ? 'invokeBlue.400' : 'base.600'}
color={isSelected ? 'base.800' : 'base.100'}
lineHeight="short"
fontSize="xs"
fontWeight={isSelected ? 'bold' : 'normal'}

View File

@ -15,7 +15,7 @@ export const MetadataItemView = memo(
return (
<Flex gap={2}>
{onRecall && <RecallButton label={label} onClick={onRecall} isDisabled={isDisabled} />}
<Flex direction={direction}>
<Flex direction={direction} fontSize="sm">
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
{label}:
</Text>

View File

@ -0,0 +1,27 @@
import { Box, Image } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
type Props = {
image_url?: string;
};
const ModelImage = ({ image_url }: Props) => {
if (!image_url) {
return <Box height="50px" minWidth="50px" />;
}
return (
<Image
src={image_url}
objectFit="cover"
objectPosition="50% 50%"
height="50px"
width="50px"
minHeight="50px"
minWidth="50px"
borderRadius="base"
/>
);
};
export default typedMemo(ModelImage);

View File

@ -22,6 +22,8 @@ import { PiTrashSimpleBold } from 'react-icons/pi';
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import ModelImage from './ModelImage';
type ModelListItemProps = {
model: AnyModelConfig;
};
@ -73,6 +75,7 @@ const ModelListItem = (props: ModelListItemProps) => {
return (
<Flex gap={2} alignItems="center" w="full">
<ModelImage image_url={model.cover_image || ''} />
<Flex
as={Button}
isChecked={isSelected}

View File

@ -0,0 +1,134 @@
import { Box, Button, IconButton, Image } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { typedMemo } from 'common/util/typedMemo';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback, useState } from 'react';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
import { useDeleteModelImageMutation, useUpdateModelImageMutation } from 'services/api/endpoints/models';
type Props = {
model_key: string | null;
model_image?: string | null;
};
const ModelImageUpload = ({ model_key, model_image }: Props) => {
const dispatch = useAppDispatch();
const [image, setImage] = useState<string | null>(model_image || null);
const { t } = useTranslation();
const [updateModelImage] = useUpdateModelImageMutation();
const [deleteModelImage] = useDeleteModelImageMutation();
const onDropAccepted = useCallback(
(files: File[]) => {
const file = files[0];
if (!file || !model_key) {
return;
}
updateModelImage({ key: model_key, image: file })
.unwrap()
.then(() => {
setImage(URL.createObjectURL(file));
dispatch(
addToast(
makeToast({
title: t('modelManager.modelImageUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelImageUpdateFailed'),
status: 'error',
})
)
);
});
},
[dispatch, model_key, t, updateModelImage]
);
const handleResetImage = useCallback(() => {
if (!model_key) {
return;
}
setImage(null);
deleteModelImage(model_key)
.unwrap()
.then(() => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelImageDeleted'),
status: 'success',
})
)
);
})
.catch((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelImageDeleteFailed'),
status: 'error',
})
)
);
});
}, [dispatch, model_key, t, deleteModelImage]);
const { getInputProps, getRootProps } = useDropzone({
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg'] },
onDropAccepted,
noDrag: true,
multiple: false,
});
if (image) {
return (
<Box position="relative">
<Image
src={image}
objectFit="cover"
objectPosition="50% 50%"
height="100px"
width="100px"
minWidth="100px"
borderRadius="base"
/>
<IconButton
position="absolute"
top="1"
right="1"
onClick={handleResetImage}
aria-label={t('modelManager.deleteModelImage')}
tooltip={t('modelManager.deleteModelImage')}
icon={<PiArrowCounterClockwiseBold size={16} />}
size="sm"
variant="link"
_hover={{ color: 'base.100' }}
/>
</Box>
);
}
return (
<>
<Button leftIcon={<PiUploadSimpleBold />} {...getRootProps()} pointerEvents="auto">
{t('modelManager.uploadImage')}
</Button>
<input {...getInputProps()} />
</>
);
};
export default typedMemo(ModelImageUpload);
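A usage sketch mirroring how the model edit page wires this component up later in this diff; both props come from the currently selected model.

// Sketch only: pass the selected model's key and its current cover image.
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />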

View File

@ -4,6 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import ModelImageUpload from './Fields/ModelImageUpload';
import { ModelMetadata } from './Metadata/ModelMetadata';
import { ModelAttrView } from './ModelAttrView';
import { ModelEdit } from './ModelEdit';
@ -25,19 +26,22 @@ export const Model = () => {
return (
<>
<Flex flexDir="column" gap={1} p={2}>
<Heading as="h2" fontSize="lg">
{data.name}
</Heading>
<Flex alignItems="center" justifyContent="space-between" gap="4" paddingRight="5">
<Flex flexDir="column" gap={1} p={2}>
<Heading as="h2" fontSize="lg">
{data.name}
</Heading>
{data.source && (
<Text variant="subtext">
{t('modelManager.source')}: {data?.source}
</Text>
)}
<Box mt="4">
<ModelAttrView label="Description" value={data.description} />
</Box>
{data.source && (
<Text variant="subtext">
{t('modelManager.source')}: {data?.source}
</Text>
)}
<Box mt="4">
<ModelAttrView label="Description" value={data.description} />
</Box>
</Flex>
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
</Flex>
<Tabs mt="4" h="100%">

View File

@ -129,7 +129,7 @@ export const ModelEdit = () => {
</Flex>
<Flex flexDir="column" gap={3} mt="4">
<Flex>
<Flex gap="4" alignItems="center">
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea fontSize="md" resize="none" {...register('description')} />
@ -145,20 +145,32 @@ export const ModelEdit = () => {
</FormControl>
</Flex>
{data.type === 'main' && data.format === 'checkpoint' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<Checkbox {...register('upcast_attention')} />
</FormControl>
</Flex>
<>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input
{...register('config_path', {
validate: (value) => (value && value.trim().length >= 3) || 'Must be at least 3 characters',
})}
/>
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={control} />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<Checkbox {...register('upcast_attention')} />
</FormControl>
</Flex>
</>
)}
</Flex>
</form>

View File

@ -90,21 +90,26 @@ export const ModelView = () => {
<ModelAttrView label={t('common.format')} value={modelData.format} />
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Flex>
{modelData.type === 'main' && (
{modelData.type === 'main' && modelData.format === 'diffusers' && modelData.repo_variant && (
<Flex gap={2}>
{modelData.format === 'diffusers' && modelData.repo_variant && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<>
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</>
)}
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
</Flex>
)}
{modelData.type === 'main' && modelData.format === 'checkpoint' && (
<>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</Flex>
</>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
@ -112,9 +117,11 @@ export const ModelView = () => {
)}
</Flex>
</Box>
<Box layerStyle="second" borderRadius="base" p={3}>
<DefaultSettings />
</Box>
{modelData.type === 'main' && (
<Box layerStyle="second" borderRadius="base" p={3}>
<DefaultSettings />
</Box>
)}
</Flex>
);
};

View File

@ -0,0 +1,55 @@
import { Button, Flex, Image, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { workflowModeChanged } from 'features/nodes/store/workflowSlice';
import InvokeLogoSVG from 'public/assets/images/invoke-symbol-wht-lrg.svg';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const EmptyState = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onClick = useCallback(() => {
dispatch(workflowModeChanged('edit'));
}, [dispatch]);
return (
<Flex
sx={{
w: 'full',
h: 'full',
userSelect: 'none',
}}
>
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
flexDir: 'column',
gap: 5,
maxW: '230px',
margin: '0 auto',
}}
>
<Image
src={InvokeLogoSVG}
alt="invoke-ai-logo"
opacity={0.2}
mixBlendMode="overlay"
w={16}
h={16}
minW={16}
minH={16}
userSelect="none"
/>
<Text textAlign="center" fontSize="md">
{t('nodes.noFieldsViewMode')}
</Text>
<Button colorScheme="invokeBlue" onClick={onClick}>
{t('nodes.edit')}
</Button>
</Flex>
</Flex>
);
};

View File

@ -7,6 +7,7 @@ import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
import { t } from 'i18next';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
import { EmptyState } from './EmptyState';
import WorkflowField from './WorkflowField';
const selector = createMemoizedSelector(selectWorkflowSlice, (workflow) => {
@ -30,7 +31,7 @@ export const WorkflowViewMode = () => {
<WorkflowField key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
))
) : (
<IAINoContentFallback label={t('nodes.noFieldsLinearview')} icon={null} />
<EmptyState />
)}
</Flex>
</ScrollableContent>

View File

@ -12,17 +12,17 @@ const WorkflowPanel = () => {
<Flex layerStyle="first" flexDir="column" w="full" h="full" borderRadius="base" p={2} gap={2}>
<Tabs variant="line" display="flex" w="full" h="full" flexDir="column">
<TabList>
<Tab>{t('common.details')}</Tab>
<Tab>{t('common.linear')}</Tab>
<Tab>{t('common.details')}</Tab>
<Tab>JSON</Tab>
</TabList>
<TabPanels>
<TabPanel>
<WorkflowGeneralTab />
<WorkflowLinearTab />
</TabPanel>
<TabPanel>
<WorkflowLinearTab />
<WorkflowGeneralTab />
</TabPanel>
<TabPanel>
<WorkflowJSONTab />

View File

@ -55,8 +55,22 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zSubModelType = z.enum([
'unet',
'text_encoder',
'text_encoder_2',
'tokenizer',
'tokenizer_2',
'vae',
'vae_decoder',
'vae_encoder',
'scheduler',
'safety_checker',
]);
const zModelIdentifier = z.object({
key: z.string().min(1),
submodel_type: zSubModelType.nullish(),
});
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
zModelIdentifier.safeParse(field).success;
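A minimal sketch of how the guard above narrows unknown data; the sample object is illustrative only.

// Sketch only: isModelIdentifier narrows unknown data to { key, submodel_type? }.
const maybeIdentifier: unknown = { key: 'abc123', submodel_type: 'vae' };
if (isModelIdentifier(maybeIdentifier)) {
  // key is a non-empty string; submodel_type is one of the enum values, null, or undefined.
  console.log(maybeIdentifier.key, maybeIdentifier.submodel_type);
}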

View File

@ -1,18 +1,22 @@
import type { RootState } from 'app/store/store';
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
import { omit } from 'lodash-es';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type {
CollectInvocation,
ControlField,
ControlNetInvocation,
CoreMetadataInvocation,
NonNullableGraph,
} from 'services/api/types';
import { isControlNetModelConfig } from 'services/api/types';
import { CONTROL_NET_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
import { getModelMetadataField, upsertMetadata } from './metadata';
export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
export const addControlNetToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
);
@ -39,7 +43,7 @@ export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableG
},
});
validControlNets.forEach((controlNet) => {
validControlNets.forEach(async (controlNet) => {
if (!controlNet.model) {
return;
}
@ -85,7 +89,17 @@ export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableG
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
controlNetMetadata.push(omit(controlNetNode, ['id', 'type', 'is_intermediate']) as ControlField);
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isControlNetModelConfig);
controlNetMetadata.push({
control_model: getModelMetadataField(modelConfig),
control_weight: weight,
control_mode: controlMode,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
image: controlNetNode.image,
});
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },

View File

@ -1,18 +1,22 @@
import type { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import { omit } from 'lodash-es';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type {
CollectInvocation,
CoreMetadataInvocation,
IPAdapterInvocation,
IPAdapterMetadataField,
NonNullableGraph,
} from 'services/api/types';
import { isIPAdapterModelConfig } from 'services/api/types';
import { IP_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
import { getModelMetadataField, upsertMetadata } from './metadata';
export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
export const addIPAdapterToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
);
@ -35,7 +39,7 @@ export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGr
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
validIPAdapters.forEach((ipAdapter) => {
validIPAdapters.forEach(async (ipAdapter) => {
if (!ipAdapter.model) {
return;
}
@ -58,9 +62,17 @@ export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGr
return;
}
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
ipAdapterMetdata.push(omit(ipAdapterNode, ['id', 'type', 'is_intermediate']) as IPAdapterMetadataField);
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isIPAdapterModelConfig);
ipAdapterMetdata.push({
weight: weight,
ip_adapter_model: getModelMetadataField(modelConfig),
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image: ipAdapterNode.image,
});
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },

View File

@ -1,16 +1,22 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { filter, size } from 'lodash-es';
import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types';
import {
type CoreMetadataInvocation,
isLoRAModelConfig,
type LoRALoaderInvocation,
type NonNullableGraph,
} from 'services/api/types';
import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
import { upsertMetadata } from './metadata';
import { getModelMetadataField, upsertMetadata } from './metadata';
export const addLoRAsToGraph = (
export const addLoRAsToGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId: string = MAIN_MODEL_LOADER
): void => {
): Promise<void> => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
@ -39,12 +45,12 @@ export const addLoRAsToGraph = (
let currentLoraIndex = 0;
const loraMetadata: CoreMetadataInvocation['loras'] = [];
enabledLoRAs.forEach((lora) => {
enabledLoRAs.forEach(async (lora) => {
const { weight } = lora;
const { key } = lora.model;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const loraLoaderNode: LoraLoaderInvocation = {
const loraLoaderNode: LoRALoaderInvocation = {
type: 'lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
@ -52,8 +58,10 @@ export const addLoRAsToGraph = (
weight,
};
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
loraMetadata.push({
model: { key },
model: getModelMetadataField(modelConfig),
weight,
});
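The truncated comment above describes how LoRA loaders are chained: each lora_loader takes the UNet and CLIP references from the previous node and passes them on. A hedged sketch of that edge wiring, assuming the usual { source, destination } edge shape and the 'unet'/'clip' field names; the real file may wire additional fields.

// Sketch only: chain two lora_loader nodes by forwarding their unet and clip outputs.
const chainLoRANodes = (graph: NonNullableGraph, fromNodeId: string, toNodeId: string) => {
  graph.edges.push({
    source: { node_id: fromNodeId, field: 'unet' },
    destination: { node_id: toNodeId, field: 'unet' },
  });
  graph.edges.push({
    source: { node_id: fromNodeId, field: 'clip' },
    destination: { node_id: toNodeId, field: 'clip' },
  });
};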

View File

@ -1,6 +1,12 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { filter, size } from 'lodash-es';
import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoraLoaderInvocation } from 'services/api/types';
import {
type CoreMetadataInvocation,
isLoRAModelConfig,
type NonNullableGraph,
type SDXLLoRALoaderInvocation,
} from 'services/api/types';
import {
LORA_LOADER,
@ -10,14 +16,14 @@ import {
SDXL_REFINER_INPAINT_CREATE_MASK,
SEAMLESS,
} from './constants';
import { upsertMetadata } from './metadata';
import { getModelMetadataField, upsertMetadata } from './metadata';
export const addSDXLLoRAsToGraph = (
export const addSDXLLoRAsToGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId: string = SDXL_MODEL_LOADER
): void => {
): Promise<void> => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
@ -55,12 +61,12 @@ export const addSDXLLoRAsToGraph = (
let lastLoraNodeId = '';
let currentLoraIndex = 0;
enabledLoRAs.forEach((lora) => {
enabledLoRAs.forEach(async (lora) => {
const { weight } = lora;
const { key } = lora.model;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const loraLoaderNode: SDXLLoraLoaderInvocation = {
const loraLoaderNode: SDXLLoRALoaderInvocation = {
type: 'sdxl_lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
@ -68,7 +74,9 @@ export const addSDXLLoRAsToGraph = (
weight,
};
loraMetadata.push({ model: { key }, weight });
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
loraMetadata.push({ model: getModelMetadataField(modelConfig), weight });
// add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode;

View File

@ -1,9 +1,11 @@
import type { RootState } from 'app/store/store';
import type {
CreateDenoiseMaskInvocation,
ImageDTO,
NonNullableGraph,
SeamlessModeInvocation,
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type CreateDenoiseMaskInvocation,
type ImageDTO,
isRefinerMainModelModelConfig,
type NonNullableGraph,
type SeamlessModeInvocation,
} from 'services/api/types';
import {
@ -25,16 +27,16 @@ import {
SDXL_REFINER_SEAMLESS,
} from './constants';
import { getSDXLStylePrompts } from './graphBuilderUtils';
import { upsertMetadata } from './metadata';
import { getModelMetadataField, upsertMetadata } from './metadata';
export const addSDXLRefinerToGraph = (
export const addSDXLRefinerToGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId?: string,
canvasInitImage?: ImageDTO,
canvasMaskImage?: ImageDTO
): void => {
): Promise<void> => {
const {
refinerModel,
refinerPositiveAestheticScore,
@ -55,9 +57,10 @@ export const addSDXLRefinerToGraph = (
const fp32 = vaePrecision === 'fp32';
const isUsingScaledDimensions = ['auto', 'manual'].includes(boundingBoxScaleMethod);
const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig);
upsertMetadata(graph, {
refiner_model: refinerModel,
refiner_model: getModelMetadataField(modelConfig),
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
refiner_cfg_scale: refinerCFGScale,

View File

@ -1,18 +1,22 @@
import type { RootState } from 'app/store/store';
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import { omit } from 'lodash-es';
import type {
CollectInvocation,
CoreMetadataInvocation,
NonNullableGraph,
T2IAdapterField,
T2IAdapterInvocation,
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type CollectInvocation,
type CoreMetadataInvocation,
isT2IAdapterModelConfig,
type NonNullableGraph,
type T2IAdapterInvocation,
} from 'services/api/types';
import { T2I_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
import { getModelMetadataField, upsertMetadata } from './metadata';
export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
export const addT2IAdaptersToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
);
@ -33,9 +37,9 @@ export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullable
},
});
const t2iAdapterMetdata: CoreMetadataInvocation['t2iAdapters'] = [];
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
validT2IAdapters.forEach((t2iAdapter) => {
validT2IAdapters.forEach(async (t2iAdapter) => {
if (!t2iAdapter.model) {
return;
}
@ -77,9 +81,18 @@ export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullable
return;
}
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode as T2IAdapterInvocation;
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
t2iAdapterMetdata.push(omit(t2iAdapterNode, ['id', 'type', 'is_intermediate']) as T2IAdapterField);
const modelConfig = await fetchModelConfigWithTypeGuard(t2iAdapter.model.key, isT2IAdapterModelConfig);
t2iAdapterMetadata.push({
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
t2i_adapter_model: getModelMetadataField(modelConfig),
weight: weight,
image: t2iAdapterNode.image,
});
graph.edges.push({
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
@ -90,6 +103,6 @@ export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullable
});
});
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetdata });
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetadata });
}
};
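All of these builders now share the same pattern: fetch the model config by key, narrow it with a type guard, then derive the metadata field from it. A hedged sketch of what fetchModelConfigWithTypeGuard is assumed to do; the real helper in features/metadata/util/modelFetchingHelpers may differ.

// Sketch only: fetch a config, narrow it with the given guard, and fail loudly on mismatch.
import type { AnyModelConfig } from 'services/api/types';

const fetchModelConfigWithTypeGuardSketch = async <T extends AnyModelConfig>(
  key: string,
  guard: (config: AnyModelConfig) => config is T,
  // the raw fetcher is injected here to keep the sketch self-contained
  fetchConfig: (key: string) => Promise<AnyModelConfig>
): Promise<T> => {
  const config = await fetchConfig(key);
  if (!guard(config)) {
    throw new Error(`Model config for key ${key} did not match the expected type`);
  }
  return config;
};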

View File

@ -1,5 +1,7 @@
import type { RootState } from 'app/store/store';
import type { NonNullableGraph } from 'services/api/types';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { ModelMetadataField, NonNullableGraph } from 'services/api/types';
import { isVAEModelConfig } from 'services/api/types';
import {
CANVAS_IMAGE_TO_IMAGE_GRAPH,
@ -23,13 +25,13 @@ import {
TEXT_TO_IMAGE_GRAPH,
VAE_LOADER,
} from './constants';
import { upsertMetadata } from './metadata';
import { getModelMetadataField, upsertMetadata } from './metadata';
export const addVAEToGraph = (
export const addVAEToGraph = async (
state: RootState,
graph: NonNullableGraph,
modelLoaderNodeId: string = MAIN_MODEL_LOADER
): void => {
): Promise<void> => {
const { vae, seamlessXAxis, seamlessYAxis } = state.generation;
const { boundingBoxScaleMethod } = state.canvas;
const { refinerModel } = state.sdxl;
@ -149,6 +151,8 @@ export const addVAEToGraph = (
}
if (vae) {
upsertMetadata(graph, { vae });
const modelConfig = await fetchModelConfigWithTypeGuard(vae.key, isVAEModelConfig);
const vaeMetadata: ModelMetadataField = getModelMetadataField(modelConfig);
upsertMetadata(graph, { vae: vaeMetadata });
}
};
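getModelMetadataField is only imported in these files; a hedged guess at the mapping it performs, assuming the enriched metadata field carries the identifying properties of the config. This is an assumption for illustration, not the actual implementation.

// Sketch only: assumed projection from a full model config onto its metadata field.
const getModelMetadataFieldSketch = (config: { key: string; hash: string; name: string; base: string; type: string }) => ({
  key: config.key,
  hash: config.hash,
  name: config.name,
  base: config.base,
  type: config.type,
});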

View File

@ -10,46 +10,46 @@ import { buildCanvasSDXLOutpaintGraph } from './buildCanvasSDXLOutpaintGraph';
import { buildCanvasSDXLTextToImageGraph } from './buildCanvasSDXLTextToImageGraph';
import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph';
export const buildCanvasGraph = (
export const buildCanvasGraph = async (
state: RootState,
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
canvasInitImage: ImageDTO | undefined,
canvasMaskImage: ImageDTO | undefined
) => {
): Promise<NonNullableGraph> => {
let graph: NonNullableGraph;
if (generationMode === 'txt2img') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLTextToImageGraph(state);
graph = await buildCanvasSDXLTextToImageGraph(state);
} else {
graph = buildCanvasTextToImageGraph(state);
graph = await buildCanvasTextToImageGraph(state);
}
} else if (generationMode === 'img2img') {
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
graph = await buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
} else {
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
graph = await buildCanvasImageToImageGraph(state, canvasInitImage);
}
} else if (generationMode === 'inpaint') {
if (!canvasInitImage || !canvasMaskImage) {
throw new Error('Missing canvas init and mask images');
}
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = await buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
} else {
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = await buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
}
} else {
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = await buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
} else {
graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = await buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
}
}
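Because every builder is now async, call sites have to await the finished graph before enqueueing it; a short sketch of the change a caller needs, inside whatever async listener or thunk invokes it.

// Sketch only: the builder now returns a Promise<NonNullableGraph>.
const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
// ...enqueue or further mutate `graph` exactly as before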

View File

@ -1,7 +1,13 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
import {
type ImageDTO,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -25,12 +31,15 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
/**
* Builds the Canvas tab's Image to Image graph.
*/
export const buildCanvasImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => {
export const buildCanvasImageToImageGraph = async (
state: RootState,
initialImage: ImageDTO
): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -306,6 +315,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -316,7 +327,7 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -335,17 +346,17 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
}
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS);
await addLoRAsToGraph(state, graph, DENOISE_LATENTS);
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -40,11 +40,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
/**
* Builds the Canvas tab's Inpaint graph.
*/
export const buildCanvasInpaintGraph = (
export const buildCanvasInpaintGraph = async (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage: ImageDTO
): NonNullableGraph => {
): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -414,17 +414,17 @@ export const buildCanvasInpaintGraph = (
}
// Add VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@ -44,11 +44,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
/**
* Builds the Canvas tab's Outpaint graph.
*/
export const buildCanvasOutpaintGraph = (
export const buildCanvasOutpaintGraph = async (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage?: ImageDTO
): NonNullableGraph => {
): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -545,18 +545,18 @@ export const buildCanvasOutpaintGraph = (
}
// Add VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,6 +1,12 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type ImageDTO,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -26,12 +32,15 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
/**
* Builds the Canvas tab's Image to Image graph.
*/
export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => {
export const buildCanvasSDXLImageToImageGraph = async (
state: RootState,
initialImage: ImageDTO
): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -307,6 +316,8 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -317,7 +328,7 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -338,24 +349,24 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
// Add Refiner if enabled
if (refinerModel) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -41,11 +41,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu
/**
* Builds the Canvas tab's Inpaint graph.
*/
export const buildCanvasSDXLInpaintGraph = (
export const buildCanvasSDXLInpaintGraph = async (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage: ImageDTO
): NonNullableGraph => {
): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -426,24 +426,31 @@ export const buildCanvasSDXLInpaintGraph = (
// Add Refiner if enabled
if (refinerModel) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage, canvasMaskImage);
await addSDXLRefinerToGraph(
state,
graph,
SDXL_DENOISE_LATENTS,
modelLoaderNodeId,
canvasInitImage,
canvasMaskImage
);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// Add VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -45,11 +45,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu
/**
* Builds the Canvas tab's Outpaint graph.
*/
export const buildCanvasSDXLOutpaintGraph = (
export const buildCanvasSDXLOutpaintGraph = async (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage?: ImageDTO
): NonNullableGraph => {
): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -555,25 +555,25 @@ export const buildCanvasSDXLOutpaintGraph = (
// Add Refiner if enabled
if (refinerModel) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// Add VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { NonNullableGraph } from 'services/api/types';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -24,12 +25,12 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
/**
* Builds the Canvas tab's Text to Image graph.
*/
export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -272,6 +273,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -284,7 +287,7 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
negative_prompt: negativePrompt,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -301,24 +304,24 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
// Add Refiner if enabled
if (refinerModel) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,7 +1,8 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -23,12 +24,12 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
/**
* Builds the Canvas tab's Text to Image graph.
*/
export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph => {
export const buildCanvasTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -262,6 +263,8 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -272,7 +275,7 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -289,17 +292,17 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
}
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,7 +1,13 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
import {
type ImageResizeInvocation,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -24,12 +30,12 @@ import {
RESIZE,
SEAMLESS,
} from './constants';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
/**
* Builds the Image to Image tab graph.
*/
export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph => {
export const buildLinearImageToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@@ -307,6 +313,8 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@@ -317,7 +325,7 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@@ -336,17 +344,17 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
}
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
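The builders now resolve the full model config up front via fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig), so the result is already narrowed to a non-refiner main-model config before it reaches the metadata helpers. The pattern in isolation, as a rough sketch with placeholder types (this is not the repository's implementation of fetchModelConfigWithTypeGuard):

// Generic async fetch + type guard: resolve a config by key, then narrow it with the guard,
// throwing if the stored config is not of the expected kind.
type BaseConfig = { key: string; type: string };

export const fetchConfigWithGuard = async <T extends BaseConfig>(
  key: string,
  fetchConfig: (key: string) => Promise<BaseConfig>,
  guard: (config: BaseConfig) => config is T
): Promise<T> => {
  const config = await fetchConfig(key);
  if (!guard(config)) {
    throw new Error(`Model ${key} is not of the expected type`);
  }
  return config; // narrowed to T from here on
};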

View File

@@ -1,6 +1,12 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type ImageResizeInvocation,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@@ -25,12 +31,12 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
/**
* Builds the Image to Image tab graph.
*/
export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableGraph => {
export const buildLinearSDXLImageToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@@ -318,6 +324,8 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@@ -328,7 +336,7 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@@ -349,25 +357,25 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
// Add Refiner if enabled
if (refinerModel) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// Add LoRA Support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
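Because the SDXL builders are now async and depend on a successful config fetch, a missing model or a failed isNonRefinerMainModelConfig check surfaces as a rejected promise rather than a synchronous throw. A hedged sketch of defensive handling at a call site (the wrapper, import path, and error handling are illustrative only):

import type { RootState } from 'app/store/store';
import { buildLinearSDXLImageToImageGraph } from './buildLinearSDXLImageToImageGraph'; // path assumed

// Illustrative wrapper: convert a rejected graph build into a null result the caller can handle.
export const buildSDXLImageToImageGraphOrNull = async (state: RootState) => {
  try {
    return await buildLinearSDXLImageToImageGraph(state);
  } catch (err) {
    console.error('SDXL image-to-image graph build failed', err); // placeholder error handling
    return null;
  }
};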

View File

@@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { NonNullableGraph } from 'services/api/types';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@@ -23,9 +24,9 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@@ -221,6 +222,8 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
],
};
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@@ -231,7 +234,7 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@@ -250,25 +253,25 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
// Add Refiner if enabled
if (refinerModel) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@@ -1,7 +1,8 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addHrfToGraph } from './addHrfToGraph';
@@ -23,9 +24,9 @@ import {
SEAMLESS,
TEXT_TO_IMAGE_GRAPH,
} from './constants';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph => {
export const buildLinearTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@@ -212,6 +213,8 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
],
};
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@@ -222,7 +225,7 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@@ -239,18 +242,18 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
}
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// High resolution fix.
if (state.hrf.hrfEnabled) {
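Note that the helpers are awaited one after another rather than combined with Promise.all: each helper mutates the same graph object, and ordering matters (the NSFW checker and watermark must be appended last). A compressed illustration of that design choice, with the graph type reduced to a placeholder:

// Sequential awaits over one shared, mutable graph. Each step appends nodes/edges to the same
// object, so running the steps in order is what keeps the resulting graph deterministic.
type MutableGraph = { nodes: Record<string, unknown>; edges: unknown[] };

export const applyGraphStepsInOrder = async (
  graph: MutableGraph,
  steps: Array<(graph: MutableGraph) => Promise<void>>
): Promise<MutableGraph> => {
  for (const step of steps) {
    await step(graph); // each step sees everything added by the previous ones
  }
  return graph;
};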

View File

@@ -1,5 +1,5 @@
import type { JSONObject } from 'common/types';
import type { CoreMetadataInvocation, NonNullableGraph } from 'services/api/types';
import type { AnyModelConfig, CoreMetadataInvocation, ModelMetadataField, NonNullableGraph } from 'services/api/types';
import { METADATA } from './constants';
@@ -71,3 +71,11 @@ export const setMetadataReceivingNode = (graph: NonNullableGraph, nodeId: string
},
});
};
export const getModelMetadataField = ({ key, hash, name, base, type }: AnyModelConfig): ModelMetadataField => ({
key,
hash,
name,
base,
type,
});
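getModelMetadataField is a pure projection from a full model config down to the five identifying fields that get written into image metadata. For illustration, with a made-up config (the literal values and the import paths below are assumptions; the cast is only there to keep the example short):

import type { AnyModelConfig } from 'services/api/types';
import { getModelMetadataField } from './metadata'; // path assumed

// Only key, hash, name, base and type survive the projection; other config fields are dropped.
const exampleConfig = {
  key: 'a1b2c3',
  hash: 'blake3:deadbeef',
  name: 'stable-diffusion-v1-5',
  base: 'sd-1',
  type: 'main',
  path: '/models/sd15', // not part of the metadata field
} as unknown as AnyModelConfig;

const modelMetadata = getModelMetadataField(exampleConfig);
// { key: 'a1b2c3', hash: 'blake3:deadbeef', name: 'stable-diffusion-v1-5', base: 'sd-1', type: 'main' }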

View File

@@ -38,6 +38,7 @@ export const ParamNegativePrompt = memo(() => {
onKeyDown={onKeyDown}
fontSize="sm"
variant="darkFilled"
paddingRight={30}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />

View File

@@ -54,6 +54,7 @@ export const ParamPositivePrompt = memo(() => {
minH={28}
onKeyDown={onKeyDown}
variant="darkFilled"
paddingRight={30}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />

View File

@@ -41,6 +41,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
onKeyDown={onKeyDown}
fontSize="sm"
variant="darkFilled"
paddingRight={30}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />

View File

@@ -38,6 +38,7 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
onKeyDown={onKeyDown}
fontSize="sm"
variant="darkFilled"
paddingRight={30}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />

View File

@@ -109,7 +109,7 @@ export const imagesApi = api.injectEndpoints({
}),
clearIntermediates: build.mutation<number, void>({
query: () => ({ url: buildImagesUrl('intermediates'), method: 'DELETE' }),
invalidatesTags: ['IntermediatesCount'],
invalidatesTags: ['IntermediatesCount', 'InvocationCacheStatus'],
}),
getImageDTO: build.query<ImageDTO, string>({
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}`) }),
@@ -125,7 +125,7 @@
paths['/api/v1/images/i/{image_name}/workflow']['get']['responses']['200']['content']['application/json'],
string
>({
query: (image_name) => ({ url: `images/i/${image_name}/workflow` }),
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/workflow`) }),
providesTags: (result, error, image_name) => [{ type: 'ImageWorkflow', id: image_name }],
keepUnusedDataFor: 86400, // 24 hours
}),

View File

@@ -23,7 +23,14 @@ export type UpdateModelArg = {
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
};
export type UpdateModelImageArg = {
key: string;
image: Blob;
};
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateModelImageResponse =
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
@@ -33,6 +40,7 @@ type DeleteModelArg = {
key: string;
};
type DeleteModelResponse = void;
type DeleteModelImageResponse = void;
type ConvertMainModelResponse =
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
@@ -144,6 +152,18 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
query: ({ key, image }) => {
const formData = new FormData();
formData.append('image', image);
return {
url: buildModelsUrl(`i/${key}/image`),
method: 'PATCH',
body: formData,
};
},
invalidatesTags: ['Model'],
}),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source }) => {
return {
@@ -163,6 +183,18 @@
},
invalidatesTags: ['Model'],
}),
deleteModelImage: build.mutation<DeleteModelImageResponse, string>({
query: (key) => {
return {
url: buildModelsUrl(`i/${key}/image`),
method: 'DELETE',
};
},
invalidatesTags: ['Model'],
}),
getModelImage: build.query<string, string>({
query: (key) => buildModelsUrl(`i/${key}/image`),
}),
convertModel: build.mutation<ConvertMainModelResponse, string>({
query: (key) => {
return {
@@ -330,7 +362,9 @@ export const {
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
useDeleteModelsMutation,
useDeleteModelImageMutation,
useUpdateModelMutation,
useUpdateModelImageMutation,
useInstallModelMutation,
useConvertModelMutation,
useSyncModelsMutation,
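The new model-image endpoints take the model key plus a Blob and submit it as multipart form data, and the generated hooks can be wired directly to a file input. A minimal component sketch (the component, its props, and the import path are illustrative; only the hook and its argument shape come from the diff):

import type { ChangeEvent } from 'react';
import { useUpdateModelImageMutation } from 'services/api/endpoints/models'; // path assumed

// Illustrative uploader: sends the selected file as the model's cover image.
export const ModelImageUploader = ({ modelKey }: { modelKey: string }) => {
  const [updateModelImage, { isLoading }] = useUpdateModelImageMutation();

  const onChange = (e: ChangeEvent<HTMLInputElement>) => {
    const file = e.target.files?.[0];
    if (file) {
      updateModelImage({ key: modelKey, image: file }); // File is a Blob, matching UpdateModelImageArg
    }
  };

  return <input type="file" accept="image/*" onChange={onChange} disabled={isLoading} />;
};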

File diff suppressed because one or more lines are too long

View File

@@ -38,7 +38,6 @@ export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDT
export type ModelType = S['ModelType'];
export type SubModelType = S['SubModelType'];
export type BaseModelType = S['BaseModelType'];
export type ControlField = S['ControlField'];
// Model Configs
@@ -120,17 +119,18 @@ export type CreateGradientMaskInvocation = S['CreateGradientMaskInvocation'];
export type CanvasPasteBackInvocation = S['CanvasPasteBackInvocation'];
export type NoiseInvocation = S['NoiseInvocation'];
export type DenoiseLatentsInvocation = S['DenoiseLatentsInvocation'];
export type SDXLLoraLoaderInvocation = S['SDXLLoraLoaderInvocation'];
export type SDXLLoRALoaderInvocation = S['SDXLLoRALoaderInvocation'];
export type ImageToLatentsInvocation = S['ImageToLatentsInvocation'];
export type LatentsToImageInvocation = S['LatentsToImageInvocation'];
export type LoraLoaderInvocation = S['LoraLoaderInvocation'];
export type LoRALoaderInvocation = S['LoRALoaderInvocation'];
export type ESRGANInvocation = S['ESRGANInvocation'];
export type ImageNSFWBlurInvocation = S['ImageNSFWBlurInvocation'];
export type ImageWatermarkInvocation = S['ImageWatermarkInvocation'];
export type SeamlessModeInvocation = S['SeamlessModeInvocation'];
export type CoreMetadataInvocation = S['CoreMetadataInvocation'];
export type IPAdapterMetadataField = S['IPAdapterMetadataField'];
export type T2IAdapterField = S['T2IAdapterField'];
// Metadata fields
export type ModelMetadataField = S['ModelMetadataField'];
// ControlNet Nodes
export type ControlNetInvocation = S['ControlNetInvocation'];

View File

@@ -33,19 +33,15 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.latent import SchedulerOutput
from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOutput, MetadataOutput
from invokeai.app.invocations.model import (
ClipField,
CLIPField,
CLIPOutput,
LoraInfo,
LoraLoaderOutput,
LoRAModelField,
MainModelField,
ModelInfo,
LoRALoaderOutput,
ModelField,
ModelLoaderOutput,
SDXLLoraLoaderOutput,
SDXLLoRALoaderOutput,
UNetField,
UNetOutput,
VaeField,
VAEModelField,
VAEField,
VAEOutput,
)
from invokeai.app.invocations.primitives import (
@@ -73,8 +69,8 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.model_management.model_manager import LoadedModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
@@ -118,20 +114,16 @@ __all__ = [
"MetadataItemOutput",
"MetadataOutput",
# invokeai.app.invocations.model
"ModelInfo",
"LoraInfo",
"ModelField",
"UNetField",
"ClipField",
"VaeField",
"MainModelField",
"LoRAModelField",
"VAEModelField",
"CLIPField",
"VAEField",
"UNetOutput",
"VAEOutput",
"CLIPOutput",
"ModelLoaderOutput",
"LoraLoaderOutput",
"SDXLLoraLoaderOutput",
"LoRALoaderOutput",
"SDXLLoRALoaderOutput",
# invokeai.app.invocations.primitives
"BooleanCollectionOutput",
"BooleanOutput",
@@ -166,7 +158,7 @@ __all__ = [
# invokeai.app.services.config.config_default
"InvokeAIAppConfig",
# invokeai.backend.model_management.model_manager
"LoadedModelInfo",
"LoadedModel",
# invokeai.backend.model_management.models.base
"BaseModelType",
"ModelType",

View File

@@ -0,0 +1,17 @@
import json
import os
import sys


def main():
    # Change working directory to the repo root
    os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

    from invokeai.app.api_app import custom_openapi

    schema = custom_openapi()
    json.dump(schema, sys.stdout, indent=2)


if __name__ == "__main__":
    main()

View File

@@ -14,17 +14,13 @@ def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Pat
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = mm2_model_manager.install.register_path(embedding_file)
loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key))
loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
assert loaded_model is not None
assert loaded_model.config.key == key
with loaded_model as model:
assert isinstance(model, TextualInversionModelRaw)
loaded_model_2 = mm2_model_manager.load_model_by_key(key)
assert loaded_model.config.key == loaded_model_2.config.key
loaded_model_3 = mm2_model_manager.load_model_by_attr(
model_name=loaded_model.config.name,
model_type=loaded_model.config.type,
base_model=loaded_model.config.base,
)
assert loaded_model.config.key == loaded_model_3.config.key
config = mm2_model_manager.store.get_model(key)
loaded_model_2 = mm2_model_manager.load.load_model(config)
assert loaded_model.config.key == loaded_model_2.config.key

View File

@@ -44,6 +44,7 @@ def mock_services() -> InvocationServices:
images=ImageService(),
invocation_cache=MemoryInvocationCache(max_cache_size=0),
logger=logging, # type: ignore
model_images=None, # type: ignore
model_manager=None, # type: ignore
download_queue=None, # type: ignore
names=None, # type: ignore

Some files were not shown because too many files have changed in this diff.