mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
55 Commits
separate-g
...
maryhipp/u
Author | SHA1 | Date | |
---|---|---|---|
52039a8d91 | |||
dc7154439d | |||
4bbe6f3548 | |||
ad70cdfe87 | |||
549d461107 | |||
cab3748010 | |||
779b3e0e8e | |||
9b48029bc9 | |||
347f1fd0b7 | |||
4af5a09a68 | |||
8df02623f2 | |||
aa88fadc30 | |||
8411029d93 | |||
239b1e8cc7 | |||
8a68355926 | |||
86aef9f31d | |||
2f6964bfa5 | |||
c1cdfd132b | |||
f6bfe5e6f2 | |||
b5a8455b5f | |||
645ef081ea | |||
e68d7fa6d7 | |||
c5ab1c7ad6 | |||
5a561cab78 | |||
132790eebe | |||
c57f6ee885 | |||
d4a2ea68fc | |||
528ac5dd25 | |||
afd9ae7712 | |||
4eefed12f0 | |||
4301a3d6fd | |||
99c0662e3f | |||
cdc0d0c182 | |||
a00369a67a | |||
4f096ac3ba | |||
f5e3341465 | |||
474852ef7e | |||
b1d72d411e | |||
46614ee28f | |||
b019f9bb8b | |||
b857692073 | |||
90fb7a1a59 | |||
56fcf6af78 | |||
c4fe7e697b | |||
2fd483dfc8 | |||
b9a9507422 | |||
f2744fd7d1 | |||
fe6e879d38 | |||
d3ab08fe10 | |||
b0615bdfd4 | |||
bab20467fb | |||
e24624109e | |||
458e7185b8 | |||
a95128f5f2 | |||
46f32c5e3c |
8
Makefile
8
Makefile
@ -10,10 +10,11 @@ help:
|
||||
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
|
||||
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
|
||||
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
|
||||
@echo "test" Run the unit tests.
|
||||
@echo "frontend-install" Install the pnpm modules needed for the front end
|
||||
@echo "test Run the unit tests."
|
||||
@echo "frontend-install Install the pnpm modules needed for the front end"
|
||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
|
||||
@echo "installer-zip Build the installer .zip file for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
|
||||
@ -53,6 +54,9 @@ frontend-build:
|
||||
frontend-dev:
|
||||
cd invokeai/frontend/web && pnpm dev
|
||||
|
||||
frontend-typegen:
|
||||
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
|
||||
|
||||
# Installer zip file
|
||||
installer-zip:
|
||||
cd installer && ./create_installer.sh
|
||||
|
@ -25,6 +25,7 @@ from ..services.invocation_cache.invocation_cache_memory import MemoryInvocation
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.model_images.model_images_default import ModelImageFileStorageDisk
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
@ -71,6 +72,8 @@ class ApiDependencies:
|
||||
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
model_images_folder = config.models_path
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
configuration = config
|
||||
@ -92,6 +95,7 @@ class ApiDependencies:
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
model_record_service=ModelRecordServiceSQL(db=db),
|
||||
@ -118,6 +122,7 @@ class ApiDependencies:
|
||||
images=images,
|
||||
invocation_cache=invocation_cache,
|
||||
logger=logger,
|
||||
model_images=model_images_service,
|
||||
model_manager=model_manager,
|
||||
download_queue=download_queue_service,
|
||||
names=names,
|
||||
|
@ -1,12 +1,16 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""FastAPI route for model configuration records."""
|
||||
|
||||
import io
|
||||
import pathlib
|
||||
import shutil
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
@ -31,6 +35,9 @@ from ..dependencies import ApiDependencies
|
||||
|
||||
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
||||
|
||||
# images are immutable; set a high max-age
|
||||
IMAGE_MAX_AGE = 31536000
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
"""Return list of configs."""
|
||||
@ -105,6 +112,9 @@ async def list_model_records(
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
)
|
||||
for model in found_models:
|
||||
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
|
||||
model.cover_image = cover_image
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@ -148,6 +158,8 @@ async def get_model_record(
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
config: AnyModelConfig = record_store.get_model(key)
|
||||
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
|
||||
config.cover_image = cover_image
|
||||
return config
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -266,6 +278,75 @@ async def update_model_record(
|
||||
return model_response
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/i/{key}/image",
|
||||
operation_id="get_model_image",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model image was fetched successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model image could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def get_model_image(
|
||||
key: str = Path(description="The name of model image file to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets an image file that previews the model"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.model_images.get_path(key)
|
||||
|
||||
response = FileResponse(
|
||||
path,
|
||||
media_type="image/png",
|
||||
filename=key + ".png",
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}/image",
|
||||
operation_id="update_model_image",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model image was updated successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def update_model_image(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
image: UploadFile,
|
||||
) -> None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_images = ApiDependencies.invoker.services.model_images
|
||||
try:
|
||||
model_images.save(pil_image, key)
|
||||
logger.info(f"Updated image for model: {key}")
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="delete_model",
|
||||
@ -296,6 +377,29 @@ async def delete_model(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}/image",
|
||||
operation_id="delete_model_image",
|
||||
responses={
|
||||
204: {"description": "Model image deleted successfully"},
|
||||
404: {"description": "Model image not found"},
|
||||
},
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_model_image(
|
||||
key: str = Path(description="Unique key of model image to remove from model_images directory."),
|
||||
) -> None:
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_images = ApiDependencies.invoker.services.model_images
|
||||
try:
|
||||
model_images.delete(key)
|
||||
logger.info(f"Deleted model image: {key}")
|
||||
return
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
# @model_manager_router.post(
|
||||
# "/i/",
|
||||
# operation_id="add_model_record",
|
||||
@ -539,7 +643,7 @@ async def convert_model(
|
||||
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
||||
|
||||
# loading the model will convert it into a cached diffusers file
|
||||
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
|
||||
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
|
||||
|
||||
# Get the path of the converted model from the loader
|
||||
cache_path = loader.convert_cache.cache_path(key)
|
||||
|
@ -2,12 +2,9 @@
|
||||
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||
# values from the command line or config file.
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
@ -20,6 +17,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@ -40,6 +38,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
from .api.dependencies import ApiDependencies
|
||||
@ -59,6 +58,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
BaseInvocation,
|
||||
UIConfigBase,
|
||||
)
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
|
||||
if is_mps_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
@ -20,7 +20,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
from invokeai.backend.util.devices import torch_dtype
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from .model import ClipField
|
||||
from .model import CLIPField
|
||||
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
|
||||
@ -46,7 +46,7 @@ class CompelInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.compel_prompt,
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
clip: ClipField = InputField(
|
||||
clip: CLIPField = InputField(
|
||||
title="CLIP",
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
@ -54,16 +54,16 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
||||
tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
||||
text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
@ -127,16 +127,16 @@ class SDXLPromptInvocationBase:
|
||||
def run_clip_compel(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
clip_field: ClipField,
|
||||
clip_field: CLIPField,
|
||||
prompt: str,
|
||||
get_pooled: bool,
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
||||
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
|
||||
tokenizer_info = context.models.load(clip_field.tokenizer)
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
||||
text_encoder_info = context.models.load(clip_field.text_encoder)
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
|
||||
@ -163,7 +163,7 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_model = lora_info.model
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
yield (lora_model, lora.weight)
|
||||
@ -253,8 +253,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
target_width: int = InputField(default=1024, description="")
|
||||
target_height: int = InputField(default=1024, description="")
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
@ -340,7 +340,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
crop_top: int = InputField(default=0, description="")
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
@ -370,10 +370,10 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
|
||||
|
||||
@invocation_output("clip_skip_output")
|
||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""Clip skip node output"""
|
||||
class CLIPSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""CLIP skip node output"""
|
||||
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -383,15 +383,15 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ClipSkipInvocation(BaseInvocation):
|
||||
class CLIPSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> CLIPSkipInvocationOutput:
|
||||
self.clip.skipped_layers += self.skipped_layers
|
||||
return ClipSkipInvocationOutput(
|
||||
return CLIPSkipInvocationOutput(
|
||||
clip=self.clip,
|
||||
)
|
||||
|
||||
|
@ -34,6 +34,7 @@ from invokeai.app.invocations.fields import (
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
@ -51,15 +52,9 @@ CONTROLNET_RESIZE_VALUES = Literal[
|
||||
]
|
||||
|
||||
|
||||
class ControlNetModelField(BaseModel):
|
||||
"""ControlNet model field"""
|
||||
|
||||
key: str = Field(description="Model config record key for the ControlNet model")
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
|
||||
control_model: ModelField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
@ -95,7 +90,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||
control_model: ModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
)
|
||||
|
@ -228,7 +228,7 @@ class ConditioningField(BaseModel):
|
||||
# endregion
|
||||
|
||||
|
||||
class MetadataField(RootModel):
|
||||
class MetadataField(RootModel[dict[str, Any]]):
|
||||
"""
|
||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||
Metadata is stored without a strict schema.
|
||||
|
@ -11,25 +11,17 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import ModelField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
|
||||
|
||||
# LS: Consider moving these two classes into model.py
|
||||
class IPAdapterModelField(BaseModel):
|
||||
key: str = Field(description="Key to the IP-Adapter model")
|
||||
|
||||
|
||||
class CLIPVisionModelField(BaseModel):
|
||||
key: str = Field(description="Key to the CLIP Vision image encoder model")
|
||||
from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType
|
||||
|
||||
|
||||
class IPAdapterField(BaseModel):
|
||||
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
|
||||
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
|
||||
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
|
||||
ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.")
|
||||
image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.")
|
||||
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||
@ -62,7 +54,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
|
||||
ip_adapter_model: IPAdapterModelField = InputField(
|
||||
ip_adapter_model: ModelField = InputField(
|
||||
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
|
||||
)
|
||||
|
||||
@ -90,18 +82,18 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||
assert isinstance(ip_adapter_info, IPAdapterConfig)
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
image_encoder_models = context.models.search_by_attrs(
|
||||
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
||||
)
|
||||
assert len(image_encoder_models) == 1
|
||||
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
|
||||
return IPAdapterOutput(
|
||||
ip_adapter=IPAdapterField(
|
||||
image=self.image,
|
||||
ip_adapter_model=self.ip_adapter_model,
|
||||
image_encoder_model=image_encoder_model,
|
||||
image_encoder_model=ModelField(key=image_encoder_models[0].key),
|
||||
weight=self.weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
|
@ -26,6 +26,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from PIL import Image, ImageFilter
|
||||
from pydantic import field_validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.fields import (
|
||||
@ -75,7 +76,7 @@ from .baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
from .model import ModelField, UNetField, VAEField
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
@ -118,7 +119,7 @@ class SchedulerInvocation(BaseInvocation):
|
||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
|
||||
vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
|
||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
|
||||
@ -153,7 +154,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
if image_tensor is not None:
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
|
||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
@ -244,12 +245,12 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
|
||||
def get_scheduler(
|
||||
context: InvocationContext,
|
||||
scheduler_info: ModelInfo,
|
||||
scheduler_info: ModelField,
|
||||
scheduler_name: str,
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
|
||||
orig_scheduler_info = context.models.load(scheduler_info)
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
@ -461,7 +462,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# and if weight is None, populate with default 1.0?
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
|
||||
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
||||
|
||||
# control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
@ -523,11 +524,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
|
||||
context.models.load(single_ip_adapter.ip_adapter_model)
|
||||
)
|
||||
|
||||
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
|
||||
|
||||
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
@ -538,6 +538,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
with image_encoder_model_info as image_encoder_model:
|
||||
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||
single_ipa_images, image_encoder_model
|
||||
@ -577,8 +578,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
|
||||
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
@ -731,12 +732,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.models.load(**self.unet.unet.model_dump())
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
@ -830,7 +832,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VaeField = InputField(
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
@ -841,8 +843,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
latents = latents.to(vae.device)
|
||||
@ -1008,7 +1010,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
image: ImageField = InputField(
|
||||
description="The image to encode",
|
||||
)
|
||||
vae: VaeField = InputField(
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
@ -1064,7 +1066,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
|
@ -8,7 +8,10 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.controlnet_image_processors import (
|
||||
CONTROLNET_MODE_VALUES,
|
||||
CONTROLNET_RESIZE_VALUES,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
@ -17,10 +20,8 @@ from invokeai.app.invocations.fields import (
|
||||
OutputField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
|
||||
from ...version import __version__
|
||||
|
||||
@ -30,10 +31,20 @@ class MetadataItemField(BaseModel):
|
||||
value: Any = Field(description=FieldDescriptions.metadata_item_value)
|
||||
|
||||
|
||||
class ModelMetadataField(BaseModel):
|
||||
"""Model Metadata Field"""
|
||||
|
||||
key: str
|
||||
hash: str
|
||||
name: str
|
||||
base: BaseModelType
|
||||
type: ModelType
|
||||
|
||||
|
||||
class LoRAMetadataField(BaseModel):
|
||||
"""LoRA Metadata Field"""
|
||||
|
||||
model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
|
||||
model: ModelMetadataField = Field(description=FieldDescriptions.lora_model)
|
||||
weight: float = Field(description=FieldDescriptions.lora_weight)
|
||||
|
||||
|
||||
@ -41,7 +52,7 @@ class IPAdapterMetadataField(BaseModel):
|
||||
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
|
||||
|
||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: IPAdapterModelField = Field(
|
||||
ip_adapter_model: ModelMetadataField = Field(
|
||||
description="The IP-Adapter model.",
|
||||
)
|
||||
weight: Union[float, list[float]] = Field(
|
||||
@ -51,6 +62,33 @@ class IPAdapterMetadataField(BaseModel):
|
||||
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
||||
|
||||
|
||||
class T2IAdapterMetadataField(BaseModel):
|
||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||
t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.")
|
||||
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
|
||||
)
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
|
||||
class ControlNetMetadataField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ModelMetadataField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
|
||||
@invocation_output("metadata_item_output")
|
||||
class MetadataItemOutput(BaseInvocationOutput):
|
||||
"""Metadata Item Output"""
|
||||
@ -140,14 +178,14 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
default=None,
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference")
|
||||
controlnets: Optional[list[ControlField]] = InputField(
|
||||
model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference")
|
||||
controlnets: Optional[list[ControlNetMetadataField]] = InputField(
|
||||
default=None, description="The ControlNets used for inference"
|
||||
)
|
||||
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
|
||||
default=None, description="The IP Adapters used for inference"
|
||||
)
|
||||
t2iAdapters: Optional[list[T2IAdapterField]] = InputField(
|
||||
t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField(
|
||||
default=None, description="The IP Adapters used for inference"
|
||||
)
|
||||
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
|
||||
@ -159,7 +197,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
default=None,
|
||||
description="The name of the initial image",
|
||||
)
|
||||
vae: Optional[VAEModelField] = InputField(
|
||||
vae: Optional[ModelMetadataField] = InputField(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
@ -190,7 +228,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
# SDXL Refiner
|
||||
refiner_model: Optional[MainModelField] = InputField(
|
||||
refiner_model: Optional[ModelMetadataField] = InputField(
|
||||
default=None,
|
||||
description="The SDXL Refiner model used",
|
||||
)
|
||||
@ -222,10 +260,9 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||
"""Collects and outputs a CoreMetadata object"""
|
||||
|
||||
return MetadataOutput(
|
||||
metadata=MetadataField.model_validate(
|
||||
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
||||
)
|
||||
)
|
||||
as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
||||
as_dict["app_version"] = __version__
|
||||
|
||||
return MetadataOutput(metadata=MetadataField.model_validate(as_dict))
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
|
||||
from ...backend.model_manager import SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -16,33 +16,33 @@ from .baseinvocation import (
|
||||
)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
class ModelField(BaseModel):
|
||||
key: str = Field(description="Key of the model")
|
||||
submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
|
||||
|
||||
|
||||
class LoraInfo(ModelInfo):
|
||||
weight: float = Field(description="Lora's weight which to use when apply to model")
|
||||
class LoRAField(BaseModel):
|
||||
lora: ModelField = Field(description="Info to load lora model")
|
||||
weight: float = Field(description="Weight to apply to lora model")
|
||||
|
||||
|
||||
class UNetField(BaseModel):
|
||||
unet: ModelInfo = Field(description="Info to load unet submodel")
|
||||
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
unet: ModelField = Field(description="Info to load unet submodel")
|
||||
scheduler: ModelField = Field(description="Info to load scheduler submodel")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
||||
|
||||
|
||||
class ClipField(BaseModel):
|
||||
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
|
||||
class CLIPField(BaseModel):
|
||||
tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
|
||||
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class VaeField(BaseModel):
|
||||
# TODO: better naming?
|
||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||
class VAEField(BaseModel):
|
||||
vae: ModelField = Field(description="Info to load vae submodel")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
|
||||
|
||||
@ -57,14 +57,14 @@ class UNetOutput(BaseInvocationOutput):
|
||||
class VAEOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a VAE field"""
|
||||
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation_output("clip_output")
|
||||
class CLIPOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a CLIP field"""
|
||||
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation_output("model_loader_output")
|
||||
@ -74,18 +74,6 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
||||
pass
|
||||
|
||||
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
key: str = Field(description="Model key")
|
||||
|
||||
|
||||
class LoRAModelField(BaseModel):
|
||||
"""LoRA model field"""
|
||||
|
||||
key: str = Field(description="LoRA model key")
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
title="Main Model",
|
||||
@ -96,62 +84,40 @@ class LoRAModelField(BaseModel):
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||
model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.models.exists(key):
|
||||
raise Exception(f"Unknown model {key}")
|
||||
if not context.models.exists(self.model.key):
|
||||
raise Exception(f"Unknown model {self.model.key}")
|
||||
|
||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.VAE,
|
||||
),
|
||||
),
|
||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
||||
vae=VAEField(vae=vae),
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("lora_loader_output")
|
||||
class LoraLoaderOutput(BaseInvocationOutput):
|
||||
class LoRALoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
class LoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None,
|
||||
@ -159,46 +125,41 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
clip: Optional[ClipField] = InputField(
|
||||
clip: Optional[CLIPField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(lora_key):
|
||||
raise Exception(f"Unkown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to clip')
|
||||
|
||||
output = LoraLoaderOutput()
|
||||
output = LoRALoaderOutput()
|
||||
|
||||
if self.unet is not None:
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet = self.unet.model_copy(deep=True)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip is not None:
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@ -207,12 +168,12 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation_output("sdxl_lora_loader_output")
|
||||
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL LoRA Loader Output"""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -222,10 +183,10 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
category="model",
|
||||
version="1.0.1",
|
||||
)
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None,
|
||||
@ -233,65 +194,59 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
clip: Optional[ClipField] = InputField(
|
||||
clip: Optional[CLIPField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP 1",
|
||||
)
|
||||
clip2: Optional[ClipField] = InputField(
|
||||
clip2: Optional[CLIPField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP 2",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(lora_key):
|
||||
raise Exception(f"Unknown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to clip')
|
||||
|
||||
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip2')
|
||||
if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to clip2')
|
||||
|
||||
output = SDXLLoraLoaderOutput()
|
||||
output = SDXLLoRALoaderOutput()
|
||||
|
||||
if self.unet is not None:
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet = self.unet.model_copy(deep=True)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip is not None:
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip2 is not None:
|
||||
output.clip2 = copy.deepcopy(self.clip2)
|
||||
output.clip2 = self.clip2.model_copy(deep=True)
|
||||
output.clip2.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@ -299,17 +254,11 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
return output
|
||||
|
||||
|
||||
class VAEModelField(BaseModel):
|
||||
"""Vae model field"""
|
||||
|
||||
key: str = Field(description="Model's key")
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
class VAELoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
vae_model: VAEModelField = InputField(
|
||||
vae_model: ModelField = InputField(
|
||||
description=FieldDescriptions.vae_model,
|
||||
input=Input.Direct,
|
||||
title="VAE",
|
||||
@ -321,7 +270,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
if not context.models.exists(key):
|
||||
raise Exception(f"Unkown vae: {key}!")
|
||||
|
||||
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
|
||||
return VAEOutput(vae=VAEField(vae=self.vae_model))
|
||||
|
||||
|
||||
@invocation_output("seamless_output")
|
||||
@ -329,7 +278,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
|
||||
"""Modified Seamless Model output"""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
|
||||
vae: Optional[VAEField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -348,7 +297,7 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
vae: Optional[VaeField] = InputField(
|
||||
vae: Optional[VAEField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.vae_model,
|
||||
input=Input.Connection,
|
||||
|
@ -8,7 +8,7 @@ from .baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
|
||||
from .model import CLIPField, ModelField, UNetField, VAEField
|
||||
|
||||
|
||||
@invocation_output("sdxl_model_loader_output")
|
||||
@ -16,9 +16,9 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL base model loader output"""
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation_output("sdxl_refiner_model_loader_output")
|
||||
@ -26,15 +26,15 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL refiner model loader output"""
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
model: MainModelField = InputField(
|
||||
model: ModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
||||
)
|
||||
# TODO: precision?
|
||||
@ -46,48 +46,19 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
if not context.models.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
return SDXLModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.VAE,
|
||||
),
|
||||
),
|
||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
||||
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
||||
vae=VAEField(vae=vae),
|
||||
)
|
||||
|
||||
|
||||
@ -101,10 +72,8 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
||||
model: MainModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model,
|
||||
input=Input.Direct,
|
||||
ui_type=UIType.SDXLRefinerModel,
|
||||
model: ModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
@ -115,34 +84,14 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
if not context.models.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
return SDXLRefinerModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.VAE,
|
||||
),
|
||||
),
|
||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
||||
vae=VAEField(vae=vae),
|
||||
)
|
||||
|
@ -10,17 +10,14 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import ModelField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
class T2IAdapterModelField(BaseModel):
|
||||
key: str = Field(description="Model record key for the T2I-Adapter model")
|
||||
|
||||
|
||||
class T2IAdapterField(BaseModel):
|
||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
|
||||
t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.")
|
||||
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||
@ -55,7 +52,7 @@ class T2IAdapterInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
||||
t2i_adapter_model: T2IAdapterModelField = InputField(
|
||||
t2i_adapter_model: ModelField = InputField(
|
||||
description="The T2I-Adapter model.",
|
||||
title="T2I-Adapter Model",
|
||||
input=Input.Direct,
|
||||
|
@ -41,8 +41,9 @@ class InvocationCacheBase(ABC):
|
||||
"""Clears the cache"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def create_key(self, invocation: BaseInvocation) -> int:
|
||||
def create_key(invocation: BaseInvocation) -> int:
|
||||
"""Gets the key for the invocation's cache item"""
|
||||
pass
|
||||
|
||||
|
@ -61,9 +61,7 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
self._delete_oldest_access(number_to_delete)
|
||||
self._cache[key] = CachedItem(
|
||||
invocation_output,
|
||||
invocation_output.model_dump_json(
|
||||
warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
|
||||
),
|
||||
invocation_output.model_dump_json(warnings=False, exclude_defaults=True, exclude_unset=True),
|
||||
)
|
||||
|
||||
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
||||
@ -81,7 +79,7 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
with self._lock:
|
||||
return self._delete(key)
|
||||
|
||||
def clear(self, *args, **kwargs) -> None:
|
||||
def clear(self) -> None:
|
||||
with self._lock:
|
||||
if self._max_cache_size == 0:
|
||||
return
|
||||
|
@ -25,6 +25,7 @@ if TYPE_CHECKING:
|
||||
from .images.images_base import ImageServiceABC
|
||||
from .invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||
from .model_images.model_images_base import ModelImageFileStorageBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
@ -49,6 +50,7 @@ class InvocationServices:
|
||||
image_files: "ImageFileStorageBase",
|
||||
image_records: "ImageRecordStorageBase",
|
||||
logger: "Logger",
|
||||
model_images: "ModelImageFileStorageBase",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
@ -72,6 +74,7 @@ class InvocationServices:
|
||||
self.image_files = image_files
|
||||
self.image_records = image_records
|
||||
self.logger = logger
|
||||
self.model_images = model_images
|
||||
self.model_manager = model_manager
|
||||
self.download_queue = download_queue
|
||||
self.performance_statistics = performance_statistics
|
||||
|
33
invokeai/app/services/model_images/model_images_base.py
Normal file
33
invokeai/app/services/model_images/model_images_base.py
Normal file
@ -0,0 +1,33 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
|
||||
class ModelImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, model_key: str) -> PILImageType:
|
||||
"""Retrieves a model image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, model_key: str) -> Path:
|
||||
"""Gets the internal path to a model image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, model_key: str) -> str | None:
|
||||
"""Gets the URL to fetch a model image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, image: PILImageType, model_key: str) -> None:
|
||||
"""Saves a model image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, model_key: str) -> None:
|
||||
"""Deletes a model image."""
|
||||
pass
|
20
invokeai/app/services/model_images/model_images_common.py
Normal file
20
invokeai/app/services/model_images/model_images_common.py
Normal file
@ -0,0 +1,20 @@
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ModelImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message="Model image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message="Model image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Model image file not deleted"):
|
||||
super().__init__(message)
|
79
invokeai/app/services/model_images/model_images_default.py
Normal file
79
invokeai/app/services/model_images/model_images_default.py
Normal file
@ -0,0 +1,79 @@
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.util.thumbnails import make_thumbnail
|
||||
|
||||
from .model_images_base import ModelImageFileStorageBase
|
||||
from .model_images_common import (
|
||||
ModelImageFileDeleteException,
|
||||
ModelImageFileNotFoundException,
|
||||
ModelImageFileSaveException,
|
||||
)
|
||||
|
||||
|
||||
class ModelImageFileStorageDisk(ModelImageFileStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
def __init__(self, model_images_folder: Path):
|
||||
self._model_images_folder = model_images_folder
|
||||
self._validate_storage_folders()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def get(self, model_key: str) -> PILImageType:
|
||||
try:
|
||||
path = self.get_path(model_key)
|
||||
|
||||
if not self._validate_path(path):
|
||||
raise ModelImageFileNotFoundException
|
||||
|
||||
return Image.open(path)
|
||||
except FileNotFoundError as e:
|
||||
raise ModelImageFileNotFoundException from e
|
||||
|
||||
def save(self, image: PILImageType, model_key: str) -> None:
|
||||
try:
|
||||
self._validate_storage_folders()
|
||||
image_path = self._model_images_folder / (model_key + ".webp")
|
||||
thumbnail = make_thumbnail(image, 256)
|
||||
thumbnail.save(image_path, format="webp")
|
||||
|
||||
except Exception as e:
|
||||
raise ModelImageFileSaveException from e
|
||||
|
||||
def get_path(self, model_key: str) -> Path:
|
||||
path = self._model_images_folder / (model_key + ".webp")
|
||||
|
||||
return path
|
||||
|
||||
def get_url(self, model_key: str) -> str | None:
|
||||
path = self.get_path(model_key)
|
||||
if not self._validate_path(path):
|
||||
return
|
||||
|
||||
return self._invoker.services.urls.get_model_image_url(model_key)
|
||||
|
||||
def delete(self, model_key: str) -> None:
|
||||
try:
|
||||
path = self.get_path(model_key)
|
||||
|
||||
if not self._validate_path(path):
|
||||
raise ModelImageFileNotFoundException
|
||||
|
||||
send2trash(path)
|
||||
|
||||
except Exception as e:
|
||||
raise ModelImageFileDeleteException from e
|
||||
|
||||
def _validate_path(self, path: Path) -> bool:
|
||||
"""Validates the path given for an image."""
|
||||
return path.exists()
|
||||
|
||||
def _validate_storage_folders(self) -> None:
|
||||
"""Checks if the required folders exist and create them if they don't"""
|
||||
self._model_images_folder.mkdir(parents=True, exist_ok=True)
|
@ -1,15 +1,11 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
@ -70,32 +66,3 @@ class ModelManagerServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_config(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_key(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
@ -1,14 +1,10 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of ModelManagerServiceBase."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@ -18,7 +14,7 @@ from ..download import DownloadQueueServiceBase
|
||||
from ..events.events_base import EventServiceBase
|
||||
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from ..model_records import ModelRecordServiceBase, UnknownModelException
|
||||
from ..model_records import ModelRecordServiceBase
|
||||
from .model_manager_base import ModelManagerServiceBase
|
||||
|
||||
|
||||
@ -64,56 +60,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
if hasattr(service, "stop"):
|
||||
service.stop(invoker)
|
||||
|
||||
def load_model_by_config(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
return self.load.load_model(model_config, submodel_type, context_data)
|
||||
|
||||
def load_model_by_key(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
config = self.store.get_model(key)
|
||||
return self.load.load_model(config, submodel_type, context_data)
|
||||
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
||||
|
||||
This is provided for API compatability with the get_model() method
|
||||
in the original model manager. However, note that LoadedModel is
|
||||
not the same as the original ModelInfo that ws returned.
|
||||
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
:param context: The invocation context.
|
||||
|
||||
Exceptions: UnknownModelException -- model with this key not known
|
||||
NotImplementedException -- a model loader was not provided at initialization time
|
||||
ValueError -- more than one model matches this combination
|
||||
"""
|
||||
configs = self.store.search_by_attr(model_name, base_model, model_type)
|
||||
if len(configs) == 0:
|
||||
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
|
||||
elif len(configs) > 1:
|
||||
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
|
||||
else:
|
||||
return self.load.load_model(configs[0], submodel, context_data)
|
||||
|
||||
@classmethod
|
||||
def build_model_manager(
|
||||
cls,
|
||||
|
@ -79,6 +79,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
||||
description="The prediction type of the model.", default=None
|
||||
)
|
||||
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
|
||||
config_path: Optional[str] = Field(description="Path to config file for model", default=None)
|
||||
|
||||
|
||||
class ModelRecordServiceBase(ABC):
|
||||
@ -129,6 +130,17 @@ class ModelRecordServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
"""
|
||||
Retrieve the configuration for the indicated model.
|
||||
|
||||
:param hash: Hash of model config to be fetched.
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
|
@ -203,6 +203,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Return True if a model with the indicated key exists in the databse.
|
||||
|
@ -1,7 +1,7 @@
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
@ -13,15 +13,16 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.model import ModelField
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
|
||||
"""
|
||||
@ -299,22 +300,25 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
def exists(self, key: str) -> bool:
|
||||
def exists(self, identifier: Union[str, "ModelField"]) -> bool:
|
||||
"""Checks if a model exists.
|
||||
|
||||
Args:
|
||||
key: The key of the model.
|
||||
identifier: The key or ModelField representing the model.
|
||||
|
||||
Returns:
|
||||
True if the model exists, False if not.
|
||||
"""
|
||||
return self._services.model_manager.store.exists(key)
|
||||
if isinstance(identifier, str):
|
||||
return self._services.model_manager.store.exists(identifier)
|
||||
|
||||
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
return self._services.model_manager.store.exists(identifier.key)
|
||||
|
||||
def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""Loads a model.
|
||||
|
||||
Args:
|
||||
key: The key of the model.
|
||||
identifier: The key or ModelField representing the model.
|
||||
submodel_type: The submodel of the model to get.
|
||||
|
||||
Returns:
|
||||
@ -324,9 +328,13 @@ class ModelsInterface(InvocationContextInterface):
|
||||
# The model manager emits events as it loads the model. It needs the context data to build
|
||||
# the event payloads.
|
||||
|
||||
return self._services.model_manager.load_model_by_key(
|
||||
key=key, submodel_type=submodel_type, context_data=self._data
|
||||
)
|
||||
if isinstance(identifier, str):
|
||||
model = self._services.model_manager.store.get_model(identifier)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
|
||||
else:
|
||||
_submodel_type = submodel_type or identifier.submodel_type
|
||||
model = self._services.model_manager.store.get_model(identifier.key)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
|
||||
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
@ -343,35 +351,29 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
An object representing the loaded model.
|
||||
"""
|
||||
return self._services.model_manager.load_model_by_attr(
|
||||
model_name=name,
|
||||
base_model=base,
|
||||
model_type=type,
|
||||
submodel=submodel_type,
|
||||
context_data=self._data,
|
||||
)
|
||||
|
||||
def get_config(self, key: str) -> AnyModelConfig:
|
||||
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
|
||||
if len(configs) == 0:
|
||||
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
if len(configs) > 1:
|
||||
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig:
|
||||
"""Gets a model's config.
|
||||
|
||||
Args:
|
||||
key: The key of the model.
|
||||
identifier: The key or ModelField representing the model.
|
||||
|
||||
Returns:
|
||||
The model's config.
|
||||
"""
|
||||
return self._services.model_manager.store.get_model(key=key)
|
||||
if isinstance(identifier, str):
|
||||
return self._services.model_manager.store.get_model(identifier)
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Gets a model's metadata, if it has any.
|
||||
|
||||
Args:
|
||||
key: The key of the model.
|
||||
|
||||
Returns:
|
||||
The model's metadata, if it has any.
|
||||
"""
|
||||
return self._services.model_manager.store.get_metadata(key=key)
|
||||
return self._services.model_manager.store.get_model(identifier.key)
|
||||
|
||||
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
||||
"""Searches for models by path.
|
||||
|
@ -8,3 +8,8 @@ class UrlServiceBase(ABC):
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets the URL for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
"""Gets the URL for a model image"""
|
||||
pass
|
||||
|
@ -4,8 +4,9 @@ from .urls_base import UrlServiceBase
|
||||
|
||||
|
||||
class LocalUrlService(UrlServiceBase):
|
||||
def __init__(self, base_url: str = "api/v1"):
|
||||
def __init__(self, base_url: str = "api/v1", base_url_v2: str = "api/v2"):
|
||||
self._base_url = base_url
|
||||
self._base_url_v2 = base_url_v2
|
||||
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
image_basename = os.path.basename(image_name)
|
||||
@ -15,3 +16,6 @@ class LocalUrlService(UrlServiceBase):
|
||||
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
|
||||
|
||||
return f"{self._base_url}/images/i/{image_basename}/full"
|
||||
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
return f"{self._base_url_v2}/models/i/{model_key}/image"
|
||||
|
@ -22,7 +22,7 @@ def generate_ti_list(
|
||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||
name_or_key = trigger[1:-1]
|
||||
try:
|
||||
loaded_model = context.models.load(key=name_or_key)
|
||||
loaded_model = context.models.load(name_or_key)
|
||||
model = loaded_model.model
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
assert loaded_model.config.base == base
|
||||
|
@ -19,7 +19,6 @@ from invokeai.app.services.model_install import (
|
||||
ModelInstallService,
|
||||
ModelInstallServiceBase,
|
||||
)
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
@ -39,7 +38,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
|
||||
db = init_db(config=app_config, logger=logger, image_files=image_files)
|
||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
||||
return obj
|
||||
|
||||
|
||||
|
@ -161,6 +161,7 @@ class ModelConfigBase(BaseModel):
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
@ -310,7 +311,7 @@ class IPAdapterConfig(ModelConfigBase):
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for ClipVision."""
|
||||
"""Model config for CLIPVision."""
|
||||
|
||||
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
@ -24,7 +24,7 @@ from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
|
||||
class LoraLoader(ModelLoader):
|
||||
class LoRALoader(ModelLoader):
|
||||
"""Class to load LoRA models."""
|
||||
|
||||
# We cheat a little bit to get access to the model base
|
||||
|
@ -23,7 +23,7 @@ from .generic_diffusers import GenericDiffusersLoader
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
class VaeLoader(GenericDiffusersLoader):
|
||||
class VAELoader(GenericDiffusersLoader):
|
||||
"""Class to load VAE models."""
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
|
@ -42,9 +42,10 @@ def install_and_load_model(
|
||||
# If the requested model is already installed, return its LoadedModel
|
||||
with contextlib.suppress(UnknownModelException):
|
||||
# TODO: Replace with wrapper call
|
||||
loaded_model: LoadedModel = model_manager.load_model_by_attr(
|
||||
configs = model_manager.store.search_by_attr(
|
||||
model_name=model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
loaded_model: LoadedModel = model_manager.load.load_model(configs[0])
|
||||
return loaded_model
|
||||
|
||||
# Install the requested model.
|
||||
@ -53,7 +54,7 @@ def install_and_load_model(
|
||||
assert job.complete
|
||||
|
||||
try:
|
||||
loaded_model = model_manager.load_model_by_config(job.config_out)
|
||||
loaded_model = model_manager.load.load_model(job.config_out)
|
||||
return loaded_model
|
||||
except UnknownModelException as e:
|
||||
raise Exception(
|
||||
|
@ -20,7 +20,6 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from invokeai.app.services.model_install import ModelInstallService
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
@ -413,7 +412,7 @@ def get_config_store() -> ModelRecordServiceSQL:
|
||||
assert output_path is not None
|
||||
image_files = DiskImageFileStorage(output_path / "images")
|
||||
db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
|
||||
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
return ModelRecordServiceSQL(db)
|
||||
|
||||
|
||||
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:
|
||||
|
@ -746,6 +746,7 @@
|
||||
"delete": "Delete",
|
||||
"deleteConfig": "Delete Config",
|
||||
"deleteModel": "Delete Model",
|
||||
"deleteModelImage": "Delete Model Image",
|
||||
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
|
||||
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
|
||||
"description": "Description",
|
||||
@ -786,6 +787,10 @@
|
||||
"modelDeleteFailed": "Failed to delete model",
|
||||
"modelEntryDeleted": "Model Entry Deleted",
|
||||
"modelExists": "Model Exists",
|
||||
"modelImageDeleted": "Model Image Deleted",
|
||||
"modelImageDeleteFailed": "Model Image Delete Failed",
|
||||
"modelImageUpdated": "Model Image Updated",
|
||||
"modelImageUpdateFailed": "Model Image Update Failed",
|
||||
"modelLocation": "Model Location",
|
||||
"modelLocationValidationMsg": "Provide the path to a local folder where your Diffusers Model is stored",
|
||||
"modelManager": "Model Manager",
|
||||
@ -818,6 +823,7 @@
|
||||
"oliveModels": "Olives",
|
||||
"onnxModels": "Onnx",
|
||||
"path": "Path",
|
||||
"pathToConfig": "Path To Config",
|
||||
"pathToCustomConfig": "Path To Custom Config",
|
||||
"pickModelType": "Pick Model Type",
|
||||
"predictionType": "Prediction Type",
|
||||
@ -852,6 +858,7 @@
|
||||
"triggerPhrases": "Trigger Phrases",
|
||||
"typePhraseHere": "Type phrase here",
|
||||
"upcastAttention": "Upcast Attention",
|
||||
"uploadImage": "Upload Image",
|
||||
"updateModel": "Update Model",
|
||||
"useCustomConfig": "Use Custom Config",
|
||||
"useDefaultSettings": "Use Default Settings",
|
||||
@ -944,6 +951,7 @@
|
||||
"doesNotExist": "does not exist",
|
||||
"downloadWorkflow": "Download Workflow JSON",
|
||||
"edge": "Edge",
|
||||
"edit": "Edit",
|
||||
"editMode": "Edit in Workflow Editor",
|
||||
"enum": "Enum",
|
||||
"enumDescription": "Enums are values that may be one of a number of options.",
|
||||
@ -1019,6 +1027,7 @@
|
||||
"nodeTemplate": "Node Template",
|
||||
"nodeType": "Node Type",
|
||||
"noFieldsLinearview": "No fields added to Linear View",
|
||||
"noFieldsViewMode": "This workflow has no selected fields to display. View the full workflow to configure values.",
|
||||
"noFieldType": "No field type",
|
||||
"noImageFoundState": "No initial image found in state",
|
||||
"noMatchingNodes": "No matching nodes",
|
||||
@ -1806,6 +1815,7 @@
|
||||
"cursorPosition": "Cursor Position",
|
||||
"darkenOutsideSelection": "Darken Outside Selection",
|
||||
"discardAll": "Discard All",
|
||||
"discardCurrent": "Discard Current",
|
||||
"downloadAsImage": "Download As Image",
|
||||
"emptyFolder": "Empty Folder",
|
||||
"emptyTempImageFolder": "Empty Temp Image Folder",
|
||||
@ -1815,6 +1825,7 @@
|
||||
"eraseBoundingBox": "Erase Bounding Box",
|
||||
"eraser": "Eraser",
|
||||
"fillBoundingBox": "Fill Bounding Box",
|
||||
"invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
|
||||
"layer": "Layer",
|
||||
"limitStrokesToBox": "Limit Strokes to Box",
|
||||
"mask": "Mask",
|
||||
|
@ -115,7 +115,8 @@
|
||||
"safetensors": "Safetensors",
|
||||
"ai": "ia",
|
||||
"file": "File",
|
||||
"toResolve": "Da risolvere"
|
||||
"toResolve": "Da risolvere",
|
||||
"add": "Aggiungi"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generazioni",
|
||||
@ -153,7 +154,12 @@
|
||||
"starImage": "Immagine preferita",
|
||||
"dropToUpload": "$t(gallery.drop) per aggiornare",
|
||||
"problemDeletingImagesDesc": "Impossibile eliminare una o più immagini",
|
||||
"problemDeletingImages": "Problema durante l'eliminazione delle immagini"
|
||||
"problemDeletingImages": "Problema durante l'eliminazione delle immagini",
|
||||
"bulkDownloadRequested": "Preparazione del download",
|
||||
"bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. L'operazione potrebbe richiedere alcuni istanti.",
|
||||
"bulkDownloadRequestFailed": "Problema durante la preparazione del download",
|
||||
"bulkDownloadStarting": "Avvio scaricamento",
|
||||
"bulkDownloadFailed": "Scaricamento fallito"
|
||||
},
|
||||
"hotkeys": {
|
||||
"keyboardShortcuts": "Tasti di scelta rapida",
|
||||
@ -505,12 +511,12 @@
|
||||
"modelSyncFailed": "Sincronizzazione modello non riuscita",
|
||||
"settings": "Impostazioni",
|
||||
"syncModels": "Sincronizza Modelli",
|
||||
"syncModelsDesc": "Se i tuoi modelli non sono sincronizzati con il back-end, puoi aggiornarli utilizzando questa opzione. Questo è generalmente utile nei casi in cui aggiorni manualmente il tuo file models.yaml o aggiungi modelli alla cartella principale di InvokeAI dopo l'avvio dell'applicazione.",
|
||||
"syncModelsDesc": "Se i tuoi modelli non sono sincronizzati con il back-end, puoi aggiornarli utilizzando questa opzione. Questo è generalmente utile nei casi in cui aggiungi modelli alla cartella principale di InvokeAI dopo l'avvio dell'applicazione.",
|
||||
"loraModels": "LoRA",
|
||||
"oliveModels": "Olive",
|
||||
"onnxModels": "ONNX",
|
||||
"noModels": "Nessun modello trovato",
|
||||
"predictionType": "Tipo di previsione (per modelli Stable Diffusion 2.x ed alcuni modelli Stable Diffusion 1.x)",
|
||||
"predictionType": "Tipo di previsione",
|
||||
"quickAdd": "Aggiunta rapida",
|
||||
"simpleModelDesc": "Fornire un percorso a un modello diffusori locale, un modello checkpoint/safetensor locale, un ID repository HuggingFace o un URL del modello checkpoint/diffusori.",
|
||||
"advanced": "Avanzate",
|
||||
@ -521,7 +527,34 @@
|
||||
"vaePrecision": "Precisione VAE",
|
||||
"noModelSelected": "Nessun modello selezionato",
|
||||
"conversionNotSupported": "Conversione non supportata",
|
||||
"configFile": "File di configurazione"
|
||||
"configFile": "File di configurazione",
|
||||
"modelName": "Nome del modello",
|
||||
"modelSettings": "Impostazioni del modello",
|
||||
"advancedImportInfo": "La scheda opzioni avanzate consente la configurazione manuale delle impostazioni del modello principale. Utilizza questa scheda solo se sei sicuro di conoscere il tipo di modello e la configurazione corretti per il modello selezionato.",
|
||||
"addAll": "Aggiungi tutto",
|
||||
"addModels": "Aggiungi modelli",
|
||||
"cancel": "Annulla",
|
||||
"edit": "Modifica",
|
||||
"imageEncoderModelId": "ID modello codificatore di immagini",
|
||||
"importQueue": "Coda di importazione",
|
||||
"modelMetadata": "Metadati del modello",
|
||||
"path": "Percorso",
|
||||
"prune": "Elimina",
|
||||
"pruneTooltip": "Elimina dalla coda le importazioni completate",
|
||||
"removeFromQueue": "Rimuovi dalla coda",
|
||||
"repoVariant": "Variante del repository",
|
||||
"scan": "Scansiona",
|
||||
"scanFolder": "Scansione cartella",
|
||||
"scanResults": "Risultati della scansione",
|
||||
"source": "Sorgente",
|
||||
"upcastAttention": "Eleva l'attenzione",
|
||||
"ztsnrTraining": "Addestramento ZTSNR",
|
||||
"typePhraseHere": "Digita la frase qui",
|
||||
"defaultSettingsSaved": "Impostazioni predefinite salvate",
|
||||
"defaultSettings": "Impostazioni predefinite",
|
||||
"metadata": "Metadati",
|
||||
"useDefaultSettings": "Usa le impostazioni predefinite",
|
||||
"triggerPhrases": "Frasi trigger"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Immagini",
|
||||
@ -603,8 +636,8 @@
|
||||
"clipSkip": "CLIP Skip",
|
||||
"aspectRatio": "Proporzioni",
|
||||
"maskAdjustmentsHeader": "Regolazioni della maschera",
|
||||
"maskBlur": "Sfocatura",
|
||||
"maskBlurMethod": "Metodo di sfocatura",
|
||||
"maskBlur": "Sfocatura maschera",
|
||||
"maskBlurMethod": "Metodo sfocatura maschera",
|
||||
"seamLowThreshold": "Basso",
|
||||
"seamHighThreshold": "Alto",
|
||||
"coherencePassHeader": "Passaggio di coerenza",
|
||||
@ -661,7 +694,8 @@
|
||||
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (potrebbe essere troppo grande)",
|
||||
"boxBlur": "Box",
|
||||
"gaussianBlur": "Gaussian",
|
||||
"remixImage": "Remixa l'immagine"
|
||||
"remixImage": "Remixa l'immagine",
|
||||
"coherenceEdgeSize": "Dimensione bordo"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Modelli",
|
||||
@ -744,8 +778,8 @@
|
||||
"canceled": "Elaborazione annullata",
|
||||
"problemCopyingImageLink": "Impossibile copiare il collegamento dell'immagine",
|
||||
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
|
||||
"parameterSet": "Parametro impostato",
|
||||
"parameterNotSet": "Parametro non impostato",
|
||||
"parameterSet": "{{parameter}} impostato",
|
||||
"parameterNotSet": "{{parameter}} non impostato",
|
||||
"nodesLoadedFailed": "Impossibile caricare i nodi",
|
||||
"nodesSaved": "Nodi salvati",
|
||||
"nodesLoaded": "Nodi caricati",
|
||||
@ -798,7 +832,10 @@
|
||||
"problemRetrievingWorkflow": "Problema nel recupero del flusso di lavoro",
|
||||
"resetInitialImage": "Reimposta l'immagine iniziale",
|
||||
"uploadInitialImage": "Carica l'immagine iniziale",
|
||||
"problemDownloadingImage": "Impossibile scaricare l'immagine"
|
||||
"problemDownloadingImage": "Impossibile scaricare l'immagine",
|
||||
"prunedQueue": "Coda ripulita",
|
||||
"modelImportCanceled": "Importazione del modello annullata",
|
||||
"modelImportRemoved": "Importazione del modello rimossa"
|
||||
},
|
||||
"tooltip": {
|
||||
"feature": {
|
||||
@ -876,7 +913,10 @@
|
||||
"antialiasing": "Anti aliasing",
|
||||
"showResultsOn": "Mostra i risultati (attivato)",
|
||||
"showResultsOff": "Mostra i risultati (disattivato)",
|
||||
"saveMask": "Salva $t(unifiedCanvas.mask)"
|
||||
"saveMask": "Salva $t(unifiedCanvas.mask)",
|
||||
"coherenceModeGaussianBlur": "Sfocatura Gaussiana",
|
||||
"coherenceModeBoxBlur": "Sfocatura Box",
|
||||
"coherenceModeStaged": "Maschera espansa"
|
||||
},
|
||||
"accessibility": {
|
||||
"modelSelect": "Seleziona modello",
|
||||
@ -1345,7 +1385,8 @@
|
||||
"allLoRAsAdded": "Tutti i LoRA aggiunti",
|
||||
"defaultVAE": "VAE predefinito",
|
||||
"incompatibleBaseModel": "Modello base incompatibile",
|
||||
"loraAlreadyAdded": "LoRA già aggiunto"
|
||||
"loraAlreadyAdded": "LoRA già aggiunto",
|
||||
"concepts": "Concetti"
|
||||
},
|
||||
"invocationCache": {
|
||||
"disable": "Disabilita",
|
||||
@ -1698,6 +1739,25 @@
|
||||
"paragraphs": [
|
||||
"Valuta le generazioni in modo che siano più simili alle immagini con un punteggio estetico elevato, in base ai dati di addestramento."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceMinDenoise": {
|
||||
"heading": "Livello minimo di riduzione del rumore",
|
||||
"paragraphs": [
|
||||
"Intensità minima di riduzione rumore per la modalità di Coerenza",
|
||||
"L'intensità minima di riduzione del rumore per la regione di coerenza durante l'inpainting o l'outpainting"
|
||||
]
|
||||
},
|
||||
"compositingMaskBlur": {
|
||||
"paragraphs": [
|
||||
"Il raggio di sfocatura della maschera."
|
||||
],
|
||||
"heading": "Sfocatura maschera"
|
||||
},
|
||||
"compositingCoherenceEdgeSize": {
|
||||
"heading": "Dimensione del bordo",
|
||||
"paragraphs": [
|
||||
"La dimensione del bordo del passaggio di coerenza."
|
||||
]
|
||||
}
|
||||
},
|
||||
"sdxl": {
|
||||
@ -1746,7 +1806,12 @@
|
||||
"scheduler": "Campionatore",
|
||||
"recallParameters": "Richiama i parametri",
|
||||
"noRecallParameters": "Nessun parametro da richiamare trovato",
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
||||
"allPrompts": "Tutti i prompt",
|
||||
"imageDimensions": "Dimensioni dell'immagine",
|
||||
"parameterSet": "Parametro {{parameter}} impostato",
|
||||
"parsingFailed": "Analisi non riuscita",
|
||||
"recallParameter": "Richiama {{label}}"
|
||||
},
|
||||
"hrf": {
|
||||
"enableHrf": "Abilita Correzione Alta Risoluzione",
|
||||
@ -1818,5 +1883,11 @@
|
||||
"image": {
|
||||
"title": "Immagine"
|
||||
}
|
||||
},
|
||||
"prompt": {
|
||||
"compatibleEmbeddings": "Incorporamenti compatibili",
|
||||
"addPromptTrigger": "Aggiungi parola chiave nel prompt",
|
||||
"noPromptTriggers": "Nessuna parola chiave disponibile",
|
||||
"noMatchingTriggers": "Nessuna parola chiave corrispondente"
|
||||
}
|
||||
}
|
||||
|
@ -52,7 +52,7 @@
|
||||
"accept": "Принять",
|
||||
"postprocessing": "Постобработка",
|
||||
"txt2img": "Текст в изображение (txt2img)",
|
||||
"linear": "Линейная обработка",
|
||||
"linear": "Линейный вид",
|
||||
"dontAskMeAgain": "Больше не спрашивать",
|
||||
"areYouSure": "Вы уверены?",
|
||||
"random": "Случайное",
|
||||
@ -117,7 +117,8 @@
|
||||
"toResolve": "Чтоб решить",
|
||||
"copy": "Копировать",
|
||||
"localSystem": "Локальная система",
|
||||
"aboutDesc": "Используя Invoke для работы? Проверьте это:"
|
||||
"aboutDesc": "Используя Invoke для работы? Проверьте это:",
|
||||
"add": "Добавить"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Генерации",
|
||||
@ -155,7 +156,12 @@
|
||||
"noImageSelected": "Изображение не выбрано",
|
||||
"setCurrentImage": "Установить как текущее изображение",
|
||||
"starImage": "Добавить в избранное",
|
||||
"dropToUpload": "$t(gallery.drop) чтоб загрузить"
|
||||
"dropToUpload": "$t(gallery.drop) чтоб загрузить",
|
||||
"bulkDownloadFailed": "Загрузка не удалась",
|
||||
"bulkDownloadStarting": "Начало загрузки",
|
||||
"bulkDownloadRequested": "Подготовка к скачиванию",
|
||||
"bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.",
|
||||
"bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания"
|
||||
},
|
||||
"hotkeys": {
|
||||
"keyboardShortcuts": "Горячие клавиши",
|
||||
@ -504,7 +510,7 @@
|
||||
"settings": "Настройки",
|
||||
"selectModel": "Выберите модель",
|
||||
"syncModels": "Синхронизация моделей",
|
||||
"syncModelsDesc": "Если ваши модели не синхронизированы с серверной частью, вы можете обновить их, используя эту опцию. Обычно это удобно в тех случаях, когда вы вручную обновляете свой файл \"models.yaml\" или добавляете модели в корневую папку InvokeAI после загрузки приложения.",
|
||||
"syncModelsDesc": "Если ваши модели не синхронизированы с серверной частью, вы можете обновить их с помощью этой опции. Обычно это удобно в тех случаях, когда вы добавляете модели в корневую папку InvokeAI или каталог автоимпорта после загрузки приложения.",
|
||||
"modelUpdateFailed": "Не удалось обновить модель",
|
||||
"modelConversionFailed": "Не удалось сконвертировать модель",
|
||||
"modelsMergeFailed": "Не удалось выполнить слияние моделей",
|
||||
@ -513,7 +519,7 @@
|
||||
"oliveModels": "Модели Olives",
|
||||
"conversionNotSupported": "Преобразование не поддерживается",
|
||||
"noModels": "Нет моделей",
|
||||
"predictionType": "Тип прогноза (для моделей Stable Diffusion 2.x и периодических моделей Stable Diffusion 1.x)",
|
||||
"predictionType": "Тип прогноза",
|
||||
"quickAdd": "Быстрое добавление",
|
||||
"simpleModelDesc": "Укажите путь к локальной модели Diffusers , локальной модели checkpoint / safetensors, идентификатор репозитория HuggingFace или URL-адрес модели контрольной checkpoint / diffusers.",
|
||||
"advanced": "Продвинутый",
|
||||
@ -524,7 +530,33 @@
|
||||
"customConfigFileLocation": "Расположение пользовательского файла конфигурации",
|
||||
"vaePrecision": "Точность VAE",
|
||||
"noModelSelected": "Модель не выбрана",
|
||||
"configFile": "Файл конфигурации"
|
||||
"configFile": "Файл конфигурации",
|
||||
"addAll": "Добавить всё",
|
||||
"addModels": "Добавить модели",
|
||||
"cancel": "Отмена",
|
||||
"defaultSettings": "Стандартные настройки",
|
||||
"importQueue": "Импортировать очередь",
|
||||
"metadata": "Метаданные",
|
||||
"imageEncoderModelId": "ID модели-энкодера изображений",
|
||||
"typePhraseHere": "Введите фразы здесь",
|
||||
"advancedImportInfo": "Вкладка «Дополнительно» позволяет вручную настроить основные параметры модели. Используйте эту вкладку только в том случае, если вы уверены, что знаете правильный тип модели и конфигурацию выбранной модели.",
|
||||
"defaultSettingsSaved": "Стандартные настройки сохранены",
|
||||
"edit": "Редактировать",
|
||||
"path": "Путь",
|
||||
"prune": "Удалить",
|
||||
"pruneTooltip": "Удалить готовые импорты из очереди",
|
||||
"removeFromQueue": "Удалить из очереди",
|
||||
"repoVariant": "Вариант репозитория",
|
||||
"scan": "Сканировать",
|
||||
"scanFolder": "Сканировать папку",
|
||||
"scanResults": "Результаты сканирования",
|
||||
"source": "Источник",
|
||||
"triggerPhrases": "Триггерные фразы",
|
||||
"useDefaultSettings": "Использовать стандартные настройки",
|
||||
"modelMetadata": "Метаданные модели",
|
||||
"modelName": "Название модели",
|
||||
"modelSettings": "Настройки модели",
|
||||
"upcastAttention": "Внимание"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Изображения",
|
||||
@ -591,7 +623,7 @@
|
||||
"hSymmetryStep": "Шаг гор. симметрии",
|
||||
"hidePreview": "Скрыть предпросмотр",
|
||||
"imageToImage": "Изображение в изображение",
|
||||
"denoisingStrength": "Сила шумоподавления",
|
||||
"denoisingStrength": "Сила зашумления",
|
||||
"copyImage": "Скопировать изображение",
|
||||
"showPreview": "Показать предпросмотр",
|
||||
"noiseSettings": "Шум",
|
||||
@ -606,8 +638,8 @@
|
||||
"clipSkip": "CLIP Пропуск",
|
||||
"aspectRatio": "Соотношение",
|
||||
"maskAdjustmentsHeader": "Настройка маски",
|
||||
"maskBlur": "Размытие",
|
||||
"maskBlurMethod": "Метод размытия",
|
||||
"maskBlur": "Размытие маски",
|
||||
"maskBlurMethod": "Метод размытия маски",
|
||||
"seamLowThreshold": "Низкий",
|
||||
"seamHighThreshold": "Высокий",
|
||||
"coherencePassHeader": "Порог Coherence",
|
||||
@ -666,7 +698,9 @@
|
||||
"lockAspectRatio": "Заблокировать соотношение",
|
||||
"boxBlur": "Размытие прямоугольника",
|
||||
"gaussianBlur": "Размытие по Гауссу",
|
||||
"remixImage": "Ремикс изображения"
|
||||
"remixImage": "Ремикс изображения",
|
||||
"coherenceMinDenoise": "Мин. шумоподавление",
|
||||
"coherenceEdgeSize": "Размер края"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Модели",
|
||||
@ -749,8 +783,8 @@
|
||||
"canceled": "Обработка отменена",
|
||||
"problemCopyingImageLink": "Не удалось скопировать ссылку на изображение",
|
||||
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
|
||||
"parameterNotSet": "Параметр не задан",
|
||||
"parameterSet": "Параметр задан",
|
||||
"parameterNotSet": "Параметр {{parameter}} не задан",
|
||||
"parameterSet": "Параметр {{parameter}} задан",
|
||||
"nodesLoaded": "Узлы загружены",
|
||||
"problemCopyingImage": "Не удается скопировать изображение",
|
||||
"nodesLoadedFailed": "Не удалось загрузить Узлы",
|
||||
@ -803,7 +837,10 @@
|
||||
"problemImportingMask": "Проблема с импортом маски",
|
||||
"problemDownloadingImage": "Не удается скачать изображение",
|
||||
"uploadInitialImage": "Загрузить начальное изображение",
|
||||
"resetInitialImage": "Сбросить начальное изображение"
|
||||
"resetInitialImage": "Сбросить начальное изображение",
|
||||
"prunedQueue": "Урезанная очередь",
|
||||
"modelImportCanceled": "Импорт модели отменен",
|
||||
"modelImportRemoved": "Импорт модели удален"
|
||||
},
|
||||
"tooltip": {
|
||||
"feature": {
|
||||
@ -1145,7 +1182,11 @@
|
||||
"reorderLinearView": "Изменить порядок линейного просмотра",
|
||||
"viewMode": "Использовать в линейном представлении",
|
||||
"editMode": "Открыть в редакторе узлов",
|
||||
"resetToDefaultValue": "Сбросить к стандартному значкнию"
|
||||
"resetToDefaultValue": "Сбросить к стандартному значкнию",
|
||||
"latentsField": "Латенты",
|
||||
"latentsCollectionDescription": "Латенты могут передаваться между узлами.",
|
||||
"latentsPolymorphicDescription": "Латенты могут передаваться между узлами.",
|
||||
"latentsFieldDescription": "Латенты могут передаваться между узлами."
|
||||
},
|
||||
"controlnet": {
|
||||
"amult": "a_mult",
|
||||
@ -1294,7 +1335,8 @@
|
||||
},
|
||||
"paramScheduler": {
|
||||
"paragraphs": [
|
||||
"Планировщик определяет, как итеративно добавлять шум к изображению или как обновлять образец на основе выходных данных модели."
|
||||
"Планировщик, используемый в процессе генерации.",
|
||||
"Каждый планировщик определяет, как итеративно добавлять шум к изображению или как обновлять образец на основе выходных данных модели."
|
||||
],
|
||||
"heading": "Планировщик"
|
||||
},
|
||||
@ -1347,7 +1389,7 @@
|
||||
"compositingCoherenceMode": {
|
||||
"heading": "Режим",
|
||||
"paragraphs": [
|
||||
"Режим прохождения когерентности."
|
||||
"Метод, используемый для создания связного изображения с вновь созданной замаскированной областью."
|
||||
]
|
||||
},
|
||||
"paramSeed": {
|
||||
@ -1365,7 +1407,7 @@
|
||||
},
|
||||
"controlNetBeginEnd": {
|
||||
"paragraphs": [
|
||||
"На каких этапах процесса шумоподавления будет применена ControlNet.",
|
||||
"Часть процесса шумоподавления, к которой будет применен адаптер контроля.",
|
||||
"ControlNet, применяемые в начале процесса, направляют композицию, а ControlNet, применяемые в конце, направляют детали."
|
||||
],
|
||||
"heading": "Процент начала/конца шага"
|
||||
@ -1381,8 +1423,8 @@
|
||||
},
|
||||
"clipSkip": {
|
||||
"paragraphs": [
|
||||
"Выберите, сколько слоев модели CLIP нужно пропустить.",
|
||||
"Некоторые модели работают лучше с определенными настройками пропуска CLIP."
|
||||
"Сколько слоев модели CLIP пропустить.",
|
||||
"Некоторые модели лучше подходят для использования с CLIP Skip."
|
||||
],
|
||||
"heading": "CLIP пропуск"
|
||||
},
|
||||
@ -1479,6 +1521,25 @@
|
||||
"paragraphs": [
|
||||
"Более высокий вес LoRA приведет к большему влиянию на конечное изображение."
|
||||
]
|
||||
},
|
||||
"compositingMaskBlur": {
|
||||
"heading": "Размытие маски",
|
||||
"paragraphs": [
|
||||
"Радиус размытия маски."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceMinDenoise": {
|
||||
"heading": "Минимальное шумоподавление",
|
||||
"paragraphs": [
|
||||
"Минимальный уровень шумоподавления для режима Coherence",
|
||||
"Минимальный уровень шумоподавления для области когерентности при перерисовывании или дорисовке"
|
||||
]
|
||||
},
|
||||
"compositingCoherenceEdgeSize": {
|
||||
"heading": "Размер края",
|
||||
"paragraphs": [
|
||||
"Размер края прохода когерентности."
|
||||
]
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
@ -1509,7 +1570,12 @@
|
||||
"steps": "Шаги",
|
||||
"scheduler": "Планировщик",
|
||||
"noRecallParameters": "Параметры для вызова не найдены",
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
||||
"parameterSet": "Параметр {{parameter}} установлен",
|
||||
"parsingFailed": "Не удалось выполнить синтаксический анализ",
|
||||
"recallParameter": "Отозвать {{label}}",
|
||||
"allPrompts": "Все запросы",
|
||||
"imageDimensions": "Размеры изображения"
|
||||
},
|
||||
"queue": {
|
||||
"status": "Статус",
|
||||
@ -1588,10 +1654,11 @@
|
||||
"denoisingStrength": "Шумоподавление",
|
||||
"refinermodel": "Модель перерисовщик",
|
||||
"posAestheticScore": "Положительная эстетическая оценка",
|
||||
"concatPromptStyle": "Объединение запроса и стиля",
|
||||
"concatPromptStyle": "Связывание запроса и стиля",
|
||||
"loading": "Загрузка...",
|
||||
"steps": "Шаги",
|
||||
"posStylePrompt": "Запрос стиля"
|
||||
"posStylePrompt": "Запрос стиля",
|
||||
"freePromptStyle": "Ручной запрос стиля"
|
||||
},
|
||||
"invocationCache": {
|
||||
"useCache": "Использовать кэш",
|
||||
@ -1678,7 +1745,8 @@
|
||||
"allLoRAsAdded": "Все LoRA добавлены",
|
||||
"defaultVAE": "Стандартное VAE",
|
||||
"incompatibleBaseModel": "Несовместимая базовая модель",
|
||||
"loraAlreadyAdded": "LoRA уже добавлена"
|
||||
"loraAlreadyAdded": "LoRA уже добавлена",
|
||||
"concepts": "Концепты"
|
||||
},
|
||||
"app": {
|
||||
"storeNotInitialized": "Магазин не инициализирован"
|
||||
@ -1696,7 +1764,7 @@
|
||||
},
|
||||
"generation": {
|
||||
"title": "Генерация",
|
||||
"conceptsTab": "Концепты",
|
||||
"conceptsTab": "LoRA",
|
||||
"modelTab": "Модель"
|
||||
},
|
||||
"advanced": {
|
||||
|
@ -5,18 +5,55 @@ import openapiTS from 'openapi-typescript';
|
||||
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
|
||||
const OUTPUT_FILE = 'src/services/api/schema.ts';
|
||||
|
||||
async function main() {
|
||||
process.stdout.write(`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...`);
|
||||
const types = await openapiTS(OPENAPI_URL, {
|
||||
async function generateTypes(schema) {
|
||||
process.stdout.write(`Generating types ${OUTPUT_FILE}...`);
|
||||
const types = await openapiTS(schema, {
|
||||
exportType: true,
|
||||
transform: (schemaObject) => {
|
||||
if ('format' in schemaObject && schemaObject.format === 'binary') {
|
||||
return schemaObject.nullable ? 'Blob | null' : 'Blob';
|
||||
}
|
||||
if (schemaObject.title === 'MetadataField') {
|
||||
// This is `Record<string, never>` by default, but it actually accepts any a dict of any valid JSON value.
|
||||
return 'Record<string, unknown>';
|
||||
}
|
||||
},
|
||||
});
|
||||
fs.writeFileSync(OUTPUT_FILE, types);
|
||||
process.stdout.write(`\nOK!\r\n`);
|
||||
}
|
||||
|
||||
async function main() {
|
||||
const encoding = 'utf-8';
|
||||
|
||||
if (process.stdin.isTTY) {
|
||||
// Handle generating types with an arg (e.g. URL or path to file)
|
||||
if (process.argv.length > 3) {
|
||||
console.error('Usage: typegen.js <openapi.json>');
|
||||
process.exit(1);
|
||||
}
|
||||
if (process.argv[2]) {
|
||||
const schema = new Buffer.from(process.argv[2], encoding);
|
||||
generateTypes(schema);
|
||||
} else {
|
||||
generateTypes(OPENAPI_URL);
|
||||
}
|
||||
} else {
|
||||
// Handle generating types from stdin
|
||||
let schema = '';
|
||||
process.stdin.setEncoding(encoding);
|
||||
|
||||
process.stdin.on('readable', function () {
|
||||
const chunk = process.stdin.read();
|
||||
if (chunk !== null) {
|
||||
schema += chunk;
|
||||
}
|
||||
});
|
||||
|
||||
process.stdin.on('end', function () {
|
||||
generateTypes(JSON.parse(schema));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
main();
|
||||
|
@ -55,6 +55,7 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addControlAdapterAutoProcessorUpdateListener } from './listeners/controlAdapterAutoProcessorUpdateListener';
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
@ -125,6 +126,7 @@ addBulkDownloadListeners(startAppListening);
|
||||
// ControlNet
|
||||
addControlNetImageProcessedListener(startAppListening);
|
||||
addControlNetAutoProcessListener(startAppListening);
|
||||
addControlAdapterAutoProcessorUpdateListener(startAppListening);
|
||||
|
||||
// Boards
|
||||
addImageAddedToBoardFulfilledListener(startAppListening);
|
||||
|
@ -38,7 +38,7 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
|
||||
type: 'image/png',
|
||||
}),
|
||||
image_category: 'control',
|
||||
is_intermediate: false,
|
||||
is_intermediate: true,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
crop_visible: false,
|
||||
postUploadAction: {
|
||||
|
@ -48,7 +48,7 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
|
||||
type: 'image/png',
|
||||
}),
|
||||
image_category: 'mask',
|
||||
is_intermediate: false,
|
||||
is_intermediate: true,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
crop_visible: false,
|
||||
postUploadAction: {
|
||||
|
@ -0,0 +1,77 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { CONTROLNET_MODEL_DEFAULT_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import {
|
||||
controlAdapterAutoConfigToggled,
|
||||
controlAdapterModelChanged,
|
||||
controlAdapterProcessortTypeChanged,
|
||||
selectControlAdapterById,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { ControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
||||
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
|
||||
export const addControlAdapterAutoProcessorUpdateListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher: isAnyOf(controlAdapterModelChanged, controlAdapterAutoConfigToggled),
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
let id;
|
||||
let model;
|
||||
|
||||
const state = getState();
|
||||
const { controlAdapters: controlAdaptersState } = state;
|
||||
|
||||
if (controlAdapterModelChanged.match(action)) {
|
||||
model = action.payload.model;
|
||||
id = action.payload.id;
|
||||
|
||||
const cn = selectControlAdapterById(controlAdaptersState, id);
|
||||
if (!cn) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isControlNetOrT2IAdapter(cn)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (controlAdapterAutoConfigToggled.match(action)) {
|
||||
id = action.payload.id;
|
||||
|
||||
const cn = selectControlAdapterById(controlAdaptersState, id);
|
||||
if (!cn) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isControlNetOrT2IAdapter(cn)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// if they turned off autoconfig, return
|
||||
if (!cn.shouldAutoConfig) {
|
||||
return;
|
||||
}
|
||||
|
||||
model = cn.model;
|
||||
}
|
||||
|
||||
if (!model || !id) {
|
||||
return;
|
||||
}
|
||||
|
||||
let processorType: ControlAdapterProcessorType | undefined = undefined;
|
||||
const { data: modelConfig } = modelsApi.endpoints.getModelConfig.select(model.key)(state);
|
||||
|
||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (modelConfig?.name.includes(modelSubstring)) {
|
||||
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(
|
||||
controlAdapterProcessortTypeChanged({ id, processorType: processorType || 'none', shouldAutoConfig: true })
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
@ -101,7 +101,7 @@ export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartLis
|
||||
).unwrap();
|
||||
}
|
||||
|
||||
const graph = buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
|
||||
const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
|
||||
|
||||
log.debug({ graph: parseify(graph) }, `Canvas graph built`);
|
||||
|
||||
|
@ -20,15 +20,15 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
if (model && model.base === 'sdxl') {
|
||||
if (action.payload.tabName === 'txt2img') {
|
||||
graph = buildLinearSDXLTextToImageGraph(state);
|
||||
graph = await buildLinearSDXLTextToImageGraph(state);
|
||||
} else {
|
||||
graph = buildLinearSDXLImageToImageGraph(state);
|
||||
graph = await buildLinearSDXLImageToImageGraph(state);
|
||||
}
|
||||
} else {
|
||||
if (action.payload.tabName === 'txt2img') {
|
||||
graph = buildLinearTextToImageGraph(state);
|
||||
graph = await buildLinearTextToImageGraph(state);
|
||||
} else {
|
||||
graph = buildLinearImageToImageGraph(state);
|
||||
graph = await buildLinearImageToImageGraph(state);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -20,7 +20,7 @@ const sx: ChakraProps['sx'] = {
|
||||
'.react-colorful__hue-pointer': colorPickerPointerStyles,
|
||||
'.react-colorful__saturation-pointer': colorPickerPointerStyles,
|
||||
'.react-colorful__alpha-pointer': colorPickerPointerStyles,
|
||||
gap: 2,
|
||||
gap: 5,
|
||||
flexDir: 'column',
|
||||
};
|
||||
|
||||
@ -39,8 +39,8 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||
<Flex sx={sx}>
|
||||
<RgbaColorPicker color={color} onChange={onChange} style={colorPickerStyles} {...rest} />
|
||||
{withNumberInput && (
|
||||
<Flex>
|
||||
<FormControl>
|
||||
<Flex gap={5}>
|
||||
<FormControl gap={0}>
|
||||
<FormLabel>{t('common.red')}</FormLabel>
|
||||
<CompositeNumberInput
|
||||
value={color.r}
|
||||
@ -52,7 +52,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||
defaultValue={90}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormControl gap={0}>
|
||||
<FormLabel>{t('common.green')}</FormLabel>
|
||||
<CompositeNumberInput
|
||||
value={color.g}
|
||||
@ -64,7 +64,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||
defaultValue={90}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormControl gap={0}>
|
||||
<FormLabel>{t('common.blue')}</FormLabel>
|
||||
<CompositeNumberInput
|
||||
value={color.b}
|
||||
@ -76,7 +76,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||
defaultValue={255}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormControl gap={0}>
|
||||
<FormLabel>{t('common.alpha')}</FormLabel>
|
||||
<CompositeNumberInput
|
||||
value={color.a}
|
||||
|
@ -29,7 +29,7 @@ import { Layer, Stage } from 'react-konva';
|
||||
import IAICanvasBoundingBoxOverlay from './IAICanvasBoundingBoxOverlay';
|
||||
import IAICanvasGrid from './IAICanvasGrid';
|
||||
import IAICanvasIntermediateImage from './IAICanvasIntermediateImage';
|
||||
import IAICanvasMaskCompositer from './IAICanvasMaskCompositer';
|
||||
import IAICanvasMaskCompositor from './IAICanvasMaskCompositor';
|
||||
import IAICanvasMaskLines from './IAICanvasMaskLines';
|
||||
import IAICanvasObjectRenderer from './IAICanvasObjectRenderer';
|
||||
import IAICanvasStagingArea from './IAICanvasStagingArea';
|
||||
@ -176,7 +176,7 @@ const IAICanvas = () => {
|
||||
</Layer>
|
||||
<Layer id="mask" visible={isMaskEnabled && !isStaging} listening={false}>
|
||||
<IAICanvasMaskLines visible={true} listening={false} />
|
||||
<IAICanvasMaskCompositer listening={false} />
|
||||
<IAICanvasMaskCompositor listening={false} />
|
||||
</Layer>
|
||||
<Layer listening={false}>
|
||||
<IAICanvasBoundingBoxOverlay />
|
||||
|
@ -16,9 +16,9 @@ const canvasMaskCompositerSelector = createMemoizedSelector(selectCanvasSlice, (
|
||||
};
|
||||
});
|
||||
|
||||
type IAICanvasMaskCompositerProps = RectConfig;
|
||||
type IAICanvasMaskCompositorProps = RectConfig;
|
||||
|
||||
const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
|
||||
const IAICanvasMaskCompositor = (props: IAICanvasMaskCompositorProps) => {
|
||||
const { ...rest } = props;
|
||||
|
||||
const { stageCoordinates, stageDimensions } = useAppSelector(canvasMaskCompositerSelector);
|
||||
@ -89,4 +89,4 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAICanvasMaskCompositer);
|
||||
export default memo(IAICanvasMaskCompositor);
|
@ -5,6 +5,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { stagingAreaImageSaved } from 'features/canvas/store/actions';
|
||||
import {
|
||||
commitStagingAreaImage,
|
||||
discardStagedImage,
|
||||
discardStagedImages,
|
||||
nextStagingAreaImage,
|
||||
prevStagingAreaImage,
|
||||
@ -22,6 +23,7 @@ import {
|
||||
PiEyeBold,
|
||||
PiEyeSlashBold,
|
||||
PiFloppyDiskBold,
|
||||
PiTrashSimpleBold,
|
||||
PiXBold,
|
||||
} from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
@ -44,6 +46,40 @@ const selector = createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
};
|
||||
});
|
||||
|
||||
const ClearStagingIntermediatesIconButton = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleDiscardStagingArea = useCallback(() => {
|
||||
dispatch(discardStagedImages());
|
||||
}, [dispatch]);
|
||||
|
||||
const handleDiscardStagingImage = useCallback(() => {
|
||||
dispatch(discardStagedImage());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<IconButton
|
||||
tooltip={`${t('unifiedCanvas.discardCurrent')}`}
|
||||
aria-label={t('unifiedCanvas.discardCurrent')}
|
||||
icon={<PiXBold />}
|
||||
onClick={handleDiscardStagingImage}
|
||||
colorScheme="invokeBlue"
|
||||
fontSize={16}
|
||||
/>
|
||||
<IconButton
|
||||
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
|
||||
aria-label={t('unifiedCanvas.discardAll')}
|
||||
icon={<PiTrashSimpleBold />}
|
||||
onClick={handleDiscardStagingArea}
|
||||
colorScheme="error"
|
||||
fontSize={16}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const IAICanvasStagingAreaToolbar = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { currentStagingAreaImage, shouldShowStagingImage, currentIndex, total } = useAppSelector(selector);
|
||||
@ -185,14 +221,7 @@ const IAICanvasStagingAreaToolbar = () => {
|
||||
onClick={handleSaveToGallery}
|
||||
colorScheme="invokeBlue"
|
||||
/>
|
||||
<IconButton
|
||||
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
|
||||
aria-label={t('unifiedCanvas.discardAll')}
|
||||
icon={<PiXBold />}
|
||||
onClick={handleDiscardStagingArea}
|
||||
colorScheme="error"
|
||||
fontSize={20}
|
||||
/>
|
||||
<ClearStagingIntermediatesIconButton />
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
);
|
||||
|
@ -18,6 +18,7 @@ import {
|
||||
setShouldAutoSave,
|
||||
setShouldCropToBoundingBoxOnSave,
|
||||
setShouldDarkenOutsideBoundingBox,
|
||||
setShouldInvertBrushSizeScrollDirection,
|
||||
setShouldRestrictStrokesToBox,
|
||||
setShouldShowCanvasDebugInfo,
|
||||
setShouldShowGrid,
|
||||
@ -40,6 +41,7 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
const shouldAutoSave = useAppSelector((s) => s.canvas.shouldAutoSave);
|
||||
const shouldCropToBoundingBoxOnSave = useAppSelector((s) => s.canvas.shouldCropToBoundingBoxOnSave);
|
||||
const shouldDarkenOutsideBoundingBox = useAppSelector((s) => s.canvas.shouldDarkenOutsideBoundingBox);
|
||||
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
|
||||
const shouldShowCanvasDebugInfo = useAppSelector((s) => s.canvas.shouldShowCanvasDebugInfo);
|
||||
const shouldShowGrid = useAppSelector((s) => s.canvas.shouldShowGrid);
|
||||
const shouldShowIntermediates = useAppSelector((s) => s.canvas.shouldShowIntermediates);
|
||||
@ -76,6 +78,10 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
const handleChangeShouldInvertBrushSizeScrollDirection = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldInvertBrushSizeScrollDirection(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
const handleChangeShouldAutoSave = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldAutoSave(e.target.checked)),
|
||||
[dispatch]
|
||||
@ -144,6 +150,13 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
<FormLabel>{t('unifiedCanvas.limitStrokesToBox')}</FormLabel>
|
||||
<Checkbox isChecked={shouldRestrictStrokesToBox} onChange={handleChangeShouldRestrictStrokesToBox} />
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('unifiedCanvas.invertBrushSizeScrollDirection')}</FormLabel>
|
||||
<Checkbox
|
||||
isChecked={shouldInvertBrushSizeScrollDirection}
|
||||
onChange={handleChangeShouldInvertBrushSizeScrollDirection}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('unifiedCanvas.showCanvasDebugInfo')}</FormLabel>
|
||||
<Checkbox isChecked={shouldShowCanvasDebugInfo} onChange={handleChangeShouldShowCanvasDebugInfo} />
|
||||
|
@ -15,6 +15,7 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
||||
const stageScale = useAppSelector((s) => s.canvas.stageScale);
|
||||
const isMoveStageKeyHeld = useStore($isMoveStageKeyHeld);
|
||||
const brushSize = useAppSelector((s) => s.canvas.brushSize);
|
||||
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
|
||||
|
||||
return useCallback(
|
||||
(e: KonvaEventObject<WheelEvent>) => {
|
||||
@ -28,10 +29,16 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
||||
// checking for ctrl key is pressed or not,
|
||||
// so that brush size can be controlled using ctrl + scroll up/down
|
||||
|
||||
// Invert the delta if the property is set to true
|
||||
let delta = e.evt.deltaY;
|
||||
if (shouldInvertBrushSizeScrollDirection) {
|
||||
delta = -delta;
|
||||
}
|
||||
|
||||
if ($ctrl.get() || $meta.get()) {
|
||||
// This equation was derived by fitting a curve to the desired brush sizes and deltas
|
||||
// see https://github.com/invoke-ai/InvokeAI/pull/5542#issuecomment-1915847565
|
||||
const targetDelta = Math.sign(e.evt.deltaY) * 0.7363 * Math.pow(1.0394, brushSize);
|
||||
const targetDelta = Math.sign(delta) * 0.7363 * Math.pow(1.0394, brushSize);
|
||||
// This needs to be clamped to prevent the delta from getting too large
|
||||
const finalDelta = clamp(targetDelta, -20, 20);
|
||||
// The new brush size is also clamped to prevent it from getting too large or small
|
||||
@ -67,7 +74,7 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
||||
dispatch(setStageCoordinates(newCoordinates));
|
||||
}
|
||||
},
|
||||
[stageRef, isMoveStageKeyHeld, stageScale, dispatch, brushSize]
|
||||
[stageRef, isMoveStageKeyHeld, brushSize, dispatch, stageScale, shouldInvertBrushSizeScrollDirection]
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -65,6 +65,7 @@ const initialCanvasState: CanvasState = {
|
||||
shouldAutoSave: false,
|
||||
shouldCropToBoundingBoxOnSave: false,
|
||||
shouldDarkenOutsideBoundingBox: false,
|
||||
shouldInvertBrushSizeScrollDirection: false,
|
||||
shouldLockBoundingBox: false,
|
||||
shouldPreserveMaskedArea: false,
|
||||
shouldRestrictStrokesToBox: true,
|
||||
@ -220,6 +221,9 @@ export const canvasSlice = createSlice({
|
||||
setShouldDarkenOutsideBoundingBox: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldDarkenOutsideBoundingBox = action.payload;
|
||||
},
|
||||
setShouldInvertBrushSizeScrollDirection: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldInvertBrushSizeScrollDirection = action.payload;
|
||||
},
|
||||
clearCanvasHistory: (state) => {
|
||||
state.pastLayerStates = [];
|
||||
state.futureLayerStates = [];
|
||||
@ -288,6 +292,31 @@ export const canvasSlice = createSlice({
|
||||
state.shouldShowStagingImage = true;
|
||||
state.batchIds = [];
|
||||
},
|
||||
discardStagedImage: (state) => {
|
||||
const { images, selectedImageIndex } = state.layerState.stagingArea;
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
|
||||
if (state.pastLayerStates.length > MAX_HISTORY) {
|
||||
state.pastLayerStates.shift();
|
||||
}
|
||||
|
||||
if (!images.length) {
|
||||
return;
|
||||
}
|
||||
|
||||
images.splice(selectedImageIndex, 1);
|
||||
|
||||
if (selectedImageIndex >= images.length) {
|
||||
state.layerState.stagingArea.selectedImageIndex = images.length - 1;
|
||||
}
|
||||
|
||||
if (!images.length) {
|
||||
state.shouldShowStagingImage = false;
|
||||
state.shouldShowStagingOutline = false;
|
||||
}
|
||||
|
||||
state.futureLayerStates = [];
|
||||
},
|
||||
addFillRect: (state) => {
|
||||
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } = state;
|
||||
|
||||
@ -655,6 +684,7 @@ export const {
|
||||
commitColorPickerColor,
|
||||
commitStagingAreaImage,
|
||||
discardStagedImages,
|
||||
discardStagedImage,
|
||||
nextStagingAreaImage,
|
||||
prevStagingAreaImage,
|
||||
redo,
|
||||
@ -674,6 +704,7 @@ export const {
|
||||
setShouldAutoSave,
|
||||
setShouldCropToBoundingBoxOnSave,
|
||||
setShouldDarkenOutsideBoundingBox,
|
||||
setShouldInvertBrushSizeScrollDirection,
|
||||
setShouldPreserveMaskedArea,
|
||||
setShouldShowBoundingBox,
|
||||
setShouldShowCanvasDebugInfo,
|
||||
|
@ -120,6 +120,7 @@ export interface CanvasState {
|
||||
shouldAutoSave: boolean;
|
||||
shouldCropToBoundingBoxOnSave: boolean;
|
||||
shouldDarkenOutsideBoundingBox: boolean;
|
||||
shouldInvertBrushSizeScrollDirection: boolean;
|
||||
shouldLockBoundingBox: boolean;
|
||||
shouldPreserveMaskedArea: boolean;
|
||||
shouldRestrictStrokesToBox: boolean;
|
||||
|
@ -1,6 +1,12 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { CONTROLNET_MODEL_DEFAULT_PROCESSORS, CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||
import type {
|
||||
ControlAdapterProcessorType,
|
||||
ControlAdapterType,
|
||||
RequiredControlAdapterProcessorNode,
|
||||
} from 'features/controlAdapters/store/types';
|
||||
import { cloneDeep } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
|
||||
import { useControlAdapterModels } from './useControlAdapterModels';
|
||||
@ -22,6 +28,29 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
|
||||
return models[0];
|
||||
}, [baseModel, models]);
|
||||
|
||||
const processor = useMemo(() => {
|
||||
let processorType;
|
||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (firstModel?.name.includes(modelSubstring)) {
|
||||
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!processorType) {
|
||||
processorType = 'none';
|
||||
}
|
||||
|
||||
const processorNode =
|
||||
processorType === 'none'
|
||||
? (cloneDeep(CONTROLNET_PROCESSORS.none.default) as RequiredControlAdapterProcessorNode)
|
||||
: (cloneDeep(
|
||||
CONTROLNET_PROCESSORS[processorType as ControlAdapterProcessorType].default
|
||||
) as RequiredControlAdapterProcessorNode);
|
||||
|
||||
return { processorType, processorNode };
|
||||
}, [firstModel]);
|
||||
|
||||
const isDisabled = useMemo(() => !firstModel, [firstModel]);
|
||||
|
||||
const addControlAdapter = useCallback(() => {
|
||||
@ -31,10 +60,10 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
|
||||
dispatch(
|
||||
controlAdapterAdded({
|
||||
type,
|
||||
overrides: { model: firstModel },
|
||||
overrides: { model: firstModel, ...processor },
|
||||
})
|
||||
);
|
||||
}, [dispatch, firstModel, isDisabled, type]);
|
||||
}, [dispatch, firstModel, isDisabled, type, processor]);
|
||||
|
||||
return [addControlAdapter, isDisabled] as const;
|
||||
};
|
||||
|
@ -13,10 +13,7 @@ import { socketInvocationError } from 'services/events/actions';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import { controlAdapterImageProcessed } from './actions';
|
||||
import {
|
||||
CONTROLNET_MODEL_DEFAULT_PROCESSORS as CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS,
|
||||
CONTROLNET_PROCESSORS,
|
||||
} from './constants';
|
||||
import { CONTROLNET_PROCESSORS } from './constants';
|
||||
import type {
|
||||
ControlAdapterConfig,
|
||||
ControlAdapterProcessorType,
|
||||
@ -215,25 +212,6 @@ export const controlAdaptersSlice = createSlice({
|
||||
|
||||
update.changes.processedControlImage = null;
|
||||
|
||||
let processorType: ControlAdapterProcessorType | undefined = undefined;
|
||||
|
||||
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
|
||||
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
|
||||
if (model.key.includes(modelSubstring)) {
|
||||
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (processorType) {
|
||||
update.changes.processorType = processorType;
|
||||
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
|
||||
.default as RequiredControlAdapterProcessorNode;
|
||||
} else {
|
||||
update.changes.processorType = 'none';
|
||||
update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
|
||||
}
|
||||
|
||||
caAdapter.updateOne(state, update);
|
||||
},
|
||||
controlAdapterWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
|
||||
@ -298,17 +276,19 @@ export const controlAdaptersSlice = createSlice({
|
||||
action: PayloadAction<{
|
||||
id: string;
|
||||
processorType: ControlAdapterProcessorType;
|
||||
shouldAutoConfig?: boolean;
|
||||
}>
|
||||
) => {
|
||||
const { id, processorType } = action.payload;
|
||||
const { id, processorType, shouldAutoConfig = false } = action.payload;
|
||||
const cn = selectControlAdapterById(state, id);
|
||||
if (!cn || !isControlNetOrT2IAdapter(cn)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const processorNode = cloneDeep(
|
||||
CONTROLNET_PROCESSORS[processorType].default
|
||||
) as RequiredControlAdapterProcessorNode;
|
||||
const processorNode =
|
||||
processorType === 'none'
|
||||
? (cloneDeep(CONTROLNET_PROCESSORS.none.default) as RequiredControlAdapterProcessorNode)
|
||||
: (cloneDeep(CONTROLNET_PROCESSORS[processorType].default) as RequiredControlAdapterProcessorNode);
|
||||
|
||||
caAdapter.updateOne(state, {
|
||||
id,
|
||||
@ -316,7 +296,7 @@ export const controlAdaptersSlice = createSlice({
|
||||
processorType,
|
||||
processedControlImage: null,
|
||||
processorNode,
|
||||
shouldAutoConfig: false,
|
||||
shouldAutoConfig,
|
||||
},
|
||||
});
|
||||
},
|
||||
@ -337,28 +317,6 @@ export const controlAdaptersSlice = createSlice({
|
||||
changes: { shouldAutoConfig: !cn.shouldAutoConfig },
|
||||
};
|
||||
|
||||
if (update.changes.shouldAutoConfig) {
|
||||
// manage the processor for the user
|
||||
let processorType: ControlAdapterProcessorType | undefined = undefined;
|
||||
|
||||
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
|
||||
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
|
||||
if (cn.model?.key.includes(modelSubstring)) {
|
||||
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (processorType) {
|
||||
update.changes.processorType = processorType;
|
||||
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
|
||||
.default as RequiredControlAdapterProcessorNode;
|
||||
} else {
|
||||
update.changes.processorType = 'none';
|
||||
update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
|
||||
}
|
||||
}
|
||||
|
||||
caAdapter.updateOne(state, update);
|
||||
},
|
||||
controlAdaptersReset: () => {
|
||||
|
@ -6,7 +6,7 @@ const AutoAddIcon = () => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Flex position="absolute" insetInlineEnd={0} top={0} p={1}>
|
||||
<Badge variant="solid" bg="invokeBlue.500">
|
||||
<Badge variant="solid" bg="invokeBlue.400">
|
||||
{t('common.auto')}
|
||||
</Badge>
|
||||
</Flex>
|
||||
|
@ -173,8 +173,8 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
||||
w="full"
|
||||
maxW="full"
|
||||
borderBottomRadius="base"
|
||||
bg={isSelected ? 'invokeBlue.500' : 'base.600'}
|
||||
color={isSelected ? 'base.50' : 'base.100'}
|
||||
bg={isSelected ? 'invokeBlue.400' : 'base.600'}
|
||||
color={isSelected ? 'base.800' : 'base.100'}
|
||||
lineHeight="short"
|
||||
fontSize="xs"
|
||||
>
|
||||
@ -193,6 +193,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
||||
overflow="hidden"
|
||||
textOverflow="ellipsis"
|
||||
noOfLines={1}
|
||||
color="inherit"
|
||||
/>
|
||||
<EditableInput sx={editableInputStyles} />
|
||||
</Editable>
|
||||
|
@ -109,8 +109,8 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
||||
w="full"
|
||||
maxW="full"
|
||||
borderBottomRadius="base"
|
||||
bg={isSelected ? 'invokeBlue.500' : 'base.600'}
|
||||
color={isSelected ? 'base.50' : 'base.100'}
|
||||
bg={isSelected ? 'invokeBlue.400' : 'base.600'}
|
||||
color={isSelected ? 'base.800' : 'base.100'}
|
||||
lineHeight="short"
|
||||
fontSize="xs"
|
||||
fontWeight={isSelected ? 'bold' : 'normal'}
|
||||
|
@ -15,7 +15,7 @@ export const MetadataItemView = memo(
|
||||
return (
|
||||
<Flex gap={2}>
|
||||
{onRecall && <RecallButton label={label} onClick={onRecall} isDisabled={isDisabled} />}
|
||||
<Flex direction={direction}>
|
||||
<Flex direction={direction} fontSize="sm">
|
||||
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
|
||||
{label}:
|
||||
</Text>
|
||||
|
@ -0,0 +1,27 @@
|
||||
import { Box, Image } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
|
||||
type Props = {
|
||||
image_url?: string;
|
||||
};
|
||||
|
||||
const ModelImage = ({ image_url }: Props) => {
|
||||
if (!image_url) {
|
||||
return <Box height="50px" minWidth="50px" />;
|
||||
}
|
||||
|
||||
return (
|
||||
<Image
|
||||
src={image_url}
|
||||
objectFit="cover"
|
||||
objectPosition="50% 50%"
|
||||
height="50px"
|
||||
width="50px"
|
||||
minHeight="50px"
|
||||
minWidth="50px"
|
||||
borderRadius="base"
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default typedMemo(ModelImage);
|
@ -22,6 +22,8 @@ import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import ModelImage from './ModelImage';
|
||||
|
||||
type ModelListItemProps = {
|
||||
model: AnyModelConfig;
|
||||
};
|
||||
@ -73,6 +75,7 @@ const ModelListItem = (props: ModelListItemProps) => {
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="center" w="full">
|
||||
<ModelImage image_url={model.cover_image || ''} />
|
||||
<Flex
|
||||
as={Button}
|
||||
isChecked={isSelected}
|
||||
|
@ -0,0 +1,134 @@
|
||||
import { Box, Button, IconButton, Image } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
|
||||
import { useDeleteModelImageMutation, useUpdateModelImageMutation } from 'services/api/endpoints/models';
|
||||
|
||||
type Props = {
|
||||
model_key: string | null;
|
||||
model_image?: string | null;
|
||||
};
|
||||
|
||||
const ModelImageUpload = ({ model_key, model_image }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const [image, setImage] = useState<string | null>(model_image || null);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [updateModelImage] = useUpdateModelImageMutation();
|
||||
const [deleteModelImage] = useDeleteModelImageMutation();
|
||||
|
||||
const onDropAccepted = useCallback(
|
||||
(files: File[]) => {
|
||||
const file = files[0];
|
||||
|
||||
if (!file || !model_key) {
|
||||
return;
|
||||
}
|
||||
|
||||
updateModelImage({ key: model_key, image: file })
|
||||
.unwrap()
|
||||
.then(() => {
|
||||
setImage(URL.createObjectURL(file));
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelImageUpdated'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelImageUpdateFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
});
|
||||
},
|
||||
[dispatch, model_key, t, updateModelImage]
|
||||
);
|
||||
|
||||
const handleResetImage = useCallback(() => {
|
||||
if (!model_key) {
|
||||
return;
|
||||
}
|
||||
setImage(null);
|
||||
deleteModelImage(model_key)
|
||||
.unwrap()
|
||||
.then(() => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelImageDeleted'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelImageDeleteFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
});
|
||||
}, [dispatch, model_key, t, deleteModelImage]);
|
||||
|
||||
const { getInputProps, getRootProps } = useDropzone({
|
||||
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
|
||||
onDropAccepted,
|
||||
noDrag: true,
|
||||
multiple: false,
|
||||
});
|
||||
|
||||
if (image) {
|
||||
return (
|
||||
<Box position="relative">
|
||||
<Image
|
||||
src={image}
|
||||
objectFit="cover"
|
||||
objectPosition="50% 50%"
|
||||
height="100px"
|
||||
width="100px"
|
||||
minWidth="100px"
|
||||
borderRadius="base"
|
||||
/>
|
||||
<IconButton
|
||||
position="absolute"
|
||||
top="1"
|
||||
right="1"
|
||||
onClick={handleResetImage}
|
||||
aria-label={t('modelManager.deleteModelImage')}
|
||||
tooltip={t('modelManager.deleteModelImage')}
|
||||
icon={<PiArrowCounterClockwiseBold size={16} />}
|
||||
size="sm"
|
||||
variant="link"
|
||||
_hover={{ color: 'base.100' }}
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button leftIcon={<PiUploadSimpleBold />} {...getRootProps()} pointerEvents="auto">
|
||||
{t('modelManager.uploadImage')}
|
||||
</Button>
|
||||
<input {...getInputProps()} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default typedMemo(ModelImageUpload);
|
@ -4,6 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import ModelImageUpload from './Fields/ModelImageUpload';
|
||||
import { ModelMetadata } from './Metadata/ModelMetadata';
|
||||
import { ModelAttrView } from './ModelAttrView';
|
||||
import { ModelEdit } from './ModelEdit';
|
||||
@ -25,19 +26,22 @@ export const Model = () => {
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex flexDir="column" gap={1} p={2}>
|
||||
<Heading as="h2" fontSize="lg">
|
||||
{data.name}
|
||||
</Heading>
|
||||
<Flex alignItems="center" justifyContent="space-between" gap="4" paddingRight="5">
|
||||
<Flex flexDir="column" gap={1} p={2}>
|
||||
<Heading as="h2" fontSize="lg">
|
||||
{data.name}
|
||||
</Heading>
|
||||
|
||||
{data.source && (
|
||||
<Text variant="subtext">
|
||||
{t('modelManager.source')}: {data?.source}
|
||||
</Text>
|
||||
)}
|
||||
<Box mt="4">
|
||||
<ModelAttrView label="Description" value={data.description} />
|
||||
</Box>
|
||||
{data.source && (
|
||||
<Text variant="subtext">
|
||||
{t('modelManager.source')}: {data?.source}
|
||||
</Text>
|
||||
)}
|
||||
<Box mt="4">
|
||||
<ModelAttrView label="Description" value={data.description} />
|
||||
</Box>
|
||||
</Flex>
|
||||
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
|
||||
</Flex>
|
||||
|
||||
<Tabs mt="4" h="100%">
|
||||
|
@ -129,7 +129,7 @@ export const ModelEdit = () => {
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={3} mt="4">
|
||||
<Flex>
|
||||
<Flex gap="4" alignItems="center">
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Textarea fontSize="md" resize="none" {...register('description')} />
|
||||
@ -145,20 +145,32 @@ export const ModelEdit = () => {
|
||||
</FormControl>
|
||||
</Flex>
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={control} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect control={control} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<Checkbox {...register('upcast_attention')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||
<Input
|
||||
{...register('config_path', {
|
||||
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
|
||||
})}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={control} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect control={control} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<Checkbox {...register('upcast_attention')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
</form>
|
||||
|
@ -90,21 +90,26 @@ export const ModelView = () => {
|
||||
<ModelAttrView label={t('common.format')} value={modelData.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
|
||||
</Flex>
|
||||
{modelData.type === 'main' && (
|
||||
|
||||
{modelData.type === 'main' && modelData.format === 'diffusers' && modelData.repo_variant && (
|
||||
<Flex gap={2}>
|
||||
{modelData.format === 'diffusers' && modelData.repo_variant && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
|
||||
)}
|
||||
{modelData.format === 'checkpoint' && (
|
||||
<>
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
|
||||
</>
|
||||
)}
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
|
||||
</Flex>
|
||||
)}
|
||||
|
||||
{modelData.type === 'main' && modelData.format === 'checkpoint' && (
|
||||
<>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
|
||||
{modelData.type === 'ip_adapter' && (
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
|
||||
@ -112,9 +117,11 @@ export const ModelView = () => {
|
||||
)}
|
||||
</Flex>
|
||||
</Box>
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<DefaultSettings />
|
||||
</Box>
|
||||
{modelData.type === 'main' && (
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<DefaultSettings />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -0,0 +1,55 @@
|
||||
import { Button, Flex, Image, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { workflowModeChanged } from 'features/nodes/store/workflowSlice';
|
||||
import InvokeLogoSVG from 'public/assets/images/invoke-symbol-wht-lrg.svg';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const EmptyState = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(workflowModeChanged('edit'));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
userSelect: 'none',
|
||||
}}
|
||||
>
|
||||
<Flex
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
borderRadius: 'base',
|
||||
flexDir: 'column',
|
||||
gap: 5,
|
||||
maxW: '230px',
|
||||
margin: '0 auto',
|
||||
}}
|
||||
>
|
||||
<Image
|
||||
src={InvokeLogoSVG}
|
||||
alt="invoke-ai-logo"
|
||||
opacity={0.2}
|
||||
mixBlendMode="overlay"
|
||||
w={16}
|
||||
h={16}
|
||||
minW={16}
|
||||
minH={16}
|
||||
userSelect="none"
|
||||
/>
|
||||
<Text textAlign="center" fontSize="md">
|
||||
{t('nodes.noFieldsViewMode')}
|
||||
</Text>
|
||||
<Button colorScheme="invokeBlue" onClick={onClick}>
|
||||
{t('nodes.edit')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -7,6 +7,7 @@ import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import { t } from 'i18next';
|
||||
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
|
||||
|
||||
import { EmptyState } from './EmptyState';
|
||||
import WorkflowField from './WorkflowField';
|
||||
|
||||
const selector = createMemoizedSelector(selectWorkflowSlice, (workflow) => {
|
||||
@ -30,7 +31,7 @@ export const WorkflowViewMode = () => {
|
||||
<WorkflowField key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
|
||||
))
|
||||
) : (
|
||||
<IAINoContentFallback label={t('nodes.noFieldsLinearview')} icon={null} />
|
||||
<EmptyState />
|
||||
)}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
|
@ -12,17 +12,17 @@ const WorkflowPanel = () => {
|
||||
<Flex layerStyle="first" flexDir="column" w="full" h="full" borderRadius="base" p={2} gap={2}>
|
||||
<Tabs variant="line" display="flex" w="full" h="full" flexDir="column">
|
||||
<TabList>
|
||||
<Tab>{t('common.details')}</Tab>
|
||||
<Tab>{t('common.linear')}</Tab>
|
||||
<Tab>{t('common.details')}</Tab>
|
||||
<Tab>JSON</Tab>
|
||||
</TabList>
|
||||
|
||||
<TabPanels>
|
||||
<TabPanel>
|
||||
<WorkflowGeneralTab />
|
||||
<WorkflowLinearTab />
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
<WorkflowLinearTab />
|
||||
<WorkflowGeneralTab />
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
<WorkflowJSONTab />
|
||||
|
@ -55,8 +55,22 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
|
||||
|
||||
// #region Model-related schemas
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||
const zSubModelType = z.enum([
|
||||
'unet',
|
||||
'text_encoder',
|
||||
'text_encoder_2',
|
||||
'tokenizer',
|
||||
'tokenizer_2',
|
||||
'vae',
|
||||
'vae_decoder',
|
||||
'vae_encoder',
|
||||
'scheduler',
|
||||
'safety_checker',
|
||||
]);
|
||||
|
||||
const zModelIdentifier = z.object({
|
||||
key: z.string().min(1),
|
||||
submodel_type: zSubModelType.nullish(),
|
||||
});
|
||||
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
|
||||
zModelIdentifier.safeParse(field).success;
|
||||
|
@ -1,18 +1,22 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { omit } from 'lodash-es';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type {
|
||||
CollectInvocation,
|
||||
ControlField,
|
||||
ControlNetInvocation,
|
||||
CoreMetadataInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { isControlNetModelConfig } from 'services/api/types';
|
||||
|
||||
import { CONTROL_NET_COLLECT } from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
|
||||
export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
||||
export const addControlNetToLinearGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): Promise<void> => {
|
||||
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
|
||||
(ca) => ca.model?.base === state.generation.model?.base
|
||||
);
|
||||
@ -39,7 +43,7 @@ export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableG
|
||||
},
|
||||
});
|
||||
|
||||
validControlNets.forEach((controlNet) => {
|
||||
validControlNets.forEach(async (controlNet) => {
|
||||
if (!controlNet.model) {
|
||||
return;
|
||||
}
|
||||
@ -85,7 +89,17 @@ export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableG
|
||||
|
||||
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
|
||||
|
||||
controlNetMetadata.push(omit(controlNetNode, ['id', 'type', 'is_intermediate']) as ControlField);
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isControlNetModelConfig);
|
||||
|
||||
controlNetMetadata.push({
|
||||
control_model: getModelMetadataField(modelConfig),
|
||||
control_weight: weight,
|
||||
control_mode: controlMode,
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
resize_mode: resizeMode,
|
||||
image: controlNetNode.image,
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetNode.id, field: 'control' },
|
||||
|
@ -1,18 +1,22 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { omit } from 'lodash-es';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type {
|
||||
CollectInvocation,
|
||||
CoreMetadataInvocation,
|
||||
IPAdapterInvocation,
|
||||
IPAdapterMetadataField,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { isIPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import { IP_ADAPTER_COLLECT } from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
|
||||
export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
||||
export const addIPAdapterToLinearGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): Promise<void> => {
|
||||
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
|
||||
(ca) => ca.model?.base === state.generation.model?.base
|
||||
);
|
||||
@ -35,7 +39,7 @@ export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGr
|
||||
|
||||
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
|
||||
|
||||
validIPAdapters.forEach((ipAdapter) => {
|
||||
validIPAdapters.forEach(async (ipAdapter) => {
|
||||
if (!ipAdapter.model) {
|
||||
return;
|
||||
}
|
||||
@ -58,9 +62,17 @@ export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGr
|
||||
return;
|
||||
}
|
||||
|
||||
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
|
||||
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
|
||||
|
||||
ipAdapterMetdata.push(omit(ipAdapterNode, ['id', 'type', 'is_intermediate']) as IPAdapterMetadataField);
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isIPAdapterModelConfig);
|
||||
|
||||
ipAdapterMetdata.push({
|
||||
weight: weight,
|
||||
ip_adapter_model: getModelMetadataField(modelConfig),
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
image: ipAdapterNode.image,
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
||||
|
@ -1,16 +1,22 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import {
|
||||
type CoreMetadataInvocation,
|
||||
isLoRAModelConfig,
|
||||
type LoRALoaderInvocation,
|
||||
type NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
|
||||
export const addLoRAsToGraph = (
|
||||
export const addLoRAsToGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string,
|
||||
modelLoaderNodeId: string = MAIN_MODEL_LOADER
|
||||
): void => {
|
||||
): Promise<void> => {
|
||||
/**
|
||||
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
||||
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
|
||||
@ -39,12 +45,12 @@ export const addLoRAsToGraph = (
|
||||
let currentLoraIndex = 0;
|
||||
const loraMetadata: CoreMetadataInvocation['loras'] = [];
|
||||
|
||||
enabledLoRAs.forEach((lora) => {
|
||||
enabledLoRAs.forEach(async (lora) => {
|
||||
const { weight } = lora;
|
||||
const { key } = lora.model;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||
|
||||
const loraLoaderNode: LoraLoaderInvocation = {
|
||||
const loraLoaderNode: LoRALoaderInvocation = {
|
||||
type: 'lora_loader',
|
||||
id: currentLoraNodeId,
|
||||
is_intermediate: true,
|
||||
@ -52,8 +58,10 @@ export const addLoRAsToGraph = (
|
||||
weight,
|
||||
};
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||
|
||||
loraMetadata.push({
|
||||
model: { key },
|
||||
model: getModelMetadataField(modelConfig),
|
||||
weight,
|
||||
});
|
||||
|
||||
|
@ -1,6 +1,12 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoraLoaderInvocation } from 'services/api/types';
|
||||
import {
|
||||
type CoreMetadataInvocation,
|
||||
isLoRAModelConfig,
|
||||
type NonNullableGraph,
|
||||
type SDXLLoRALoaderInvocation,
|
||||
} from 'services/api/types';
|
||||
|
||||
import {
|
||||
LORA_LOADER,
|
||||
@ -10,14 +16,14 @@ import {
|
||||
SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
|
||||
export const addSDXLLoRAsToGraph = (
|
||||
export const addSDXLLoRAsToGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string,
|
||||
modelLoaderNodeId: string = SDXL_MODEL_LOADER
|
||||
): void => {
|
||||
): Promise<void> => {
|
||||
/**
|
||||
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
||||
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
|
||||
@ -55,12 +61,12 @@ export const addSDXLLoRAsToGraph = (
|
||||
let lastLoraNodeId = '';
|
||||
let currentLoraIndex = 0;
|
||||
|
||||
enabledLoRAs.forEach((lora) => {
|
||||
enabledLoRAs.forEach(async (lora) => {
|
||||
const { weight } = lora;
|
||||
const { key } = lora.model;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||
|
||||
const loraLoaderNode: SDXLLoraLoaderInvocation = {
|
||||
const loraLoaderNode: SDXLLoRALoaderInvocation = {
|
||||
type: 'sdxl_lora_loader',
|
||||
id: currentLoraNodeId,
|
||||
is_intermediate: true,
|
||||
@ -68,7 +74,9 @@ export const addSDXLLoRAsToGraph = (
|
||||
weight,
|
||||
};
|
||||
|
||||
loraMetadata.push({ model: { key }, weight });
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||
|
||||
loraMetadata.push({ model: getModelMetadataField(modelConfig), weight });
|
||||
|
||||
// add to graph
|
||||
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
||||
|
@ -1,9 +1,11 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type {
|
||||
CreateDenoiseMaskInvocation,
|
||||
ImageDTO,
|
||||
NonNullableGraph,
|
||||
SeamlessModeInvocation,
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import {
|
||||
type CreateDenoiseMaskInvocation,
|
||||
type ImageDTO,
|
||||
isRefinerMainModelModelConfig,
|
||||
type NonNullableGraph,
|
||||
type SeamlessModeInvocation,
|
||||
} from 'services/api/types';
|
||||
|
||||
import {
|
||||
@ -25,16 +27,16 @@ import {
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
} from './constants';
|
||||
import { getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
|
||||
export const addSDXLRefinerToGraph = (
|
||||
export const addSDXLRefinerToGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string,
|
||||
modelLoaderNodeId?: string,
|
||||
canvasInitImage?: ImageDTO,
|
||||
canvasMaskImage?: ImageDTO
|
||||
): void => {
|
||||
): Promise<void> => {
|
||||
const {
|
||||
refinerModel,
|
||||
refinerPositiveAestheticScore,
|
||||
@ -55,9 +57,10 @@ export const addSDXLRefinerToGraph = (
|
||||
const fp32 = vaePrecision === 'fp32';
|
||||
|
||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(boundingBoxScaleMethod);
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig);
|
||||
|
||||
upsertMetadata(graph, {
|
||||
refiner_model: refinerModel,
|
||||
refiner_model: getModelMetadataField(modelConfig),
|
||||
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
|
||||
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
|
||||
refiner_cfg_scale: refinerCFGScale,
|
||||
|
@ -1,18 +1,22 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { omit } from 'lodash-es';
|
||||
import type {
|
||||
CollectInvocation,
|
||||
CoreMetadataInvocation,
|
||||
NonNullableGraph,
|
||||
T2IAdapterField,
|
||||
T2IAdapterInvocation,
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import {
|
||||
type CollectInvocation,
|
||||
type CoreMetadataInvocation,
|
||||
isT2IAdapterModelConfig,
|
||||
type NonNullableGraph,
|
||||
type T2IAdapterInvocation,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { T2I_ADAPTER_COLLECT } from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
|
||||
export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
||||
export const addT2IAdaptersToLinearGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): Promise<void> => {
|
||||
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
|
||||
(ca) => ca.model?.base === state.generation.model?.base
|
||||
);
|
||||
@ -33,9 +37,9 @@ export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullable
|
||||
},
|
||||
});
|
||||
|
||||
const t2iAdapterMetdata: CoreMetadataInvocation['t2iAdapters'] = [];
|
||||
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
|
||||
|
||||
validT2IAdapters.forEach((t2iAdapter) => {
|
||||
validT2IAdapters.forEach(async (t2iAdapter) => {
|
||||
if (!t2iAdapter.model) {
|
||||
return;
|
||||
}
|
||||
@ -77,9 +81,18 @@ export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullable
|
||||
return;
|
||||
}
|
||||
|
||||
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode as T2IAdapterInvocation;
|
||||
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
|
||||
|
||||
t2iAdapterMetdata.push(omit(t2iAdapterNode, ['id', 'type', 'is_intermediate']) as T2IAdapterField);
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(t2iAdapter.model.key, isT2IAdapterModelConfig);
|
||||
|
||||
t2iAdapterMetadata.push({
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
resize_mode: resizeMode,
|
||||
t2i_adapter_model: getModelMetadataField(modelConfig),
|
||||
weight: weight,
|
||||
image: t2iAdapterNode.image,
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
||||
@ -90,6 +103,6 @@ export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullable
|
||||
});
|
||||
});
|
||||
|
||||
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetdata });
|
||||
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetadata });
|
||||
}
|
||||
};
|
||||
|
@ -1,5 +1,7 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type { ModelMetadataField, NonNullableGraph } from 'services/api/types';
|
||||
import { isVAEModelConfig } from 'services/api/types';
|
||||
|
||||
import {
|
||||
CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||
@ -23,13 +25,13 @@ import {
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
VAE_LOADER,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
|
||||
export const addVAEToGraph = (
|
||||
export const addVAEToGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
modelLoaderNodeId: string = MAIN_MODEL_LOADER
|
||||
): void => {
|
||||
): Promise<void> => {
|
||||
const { vae, seamlessXAxis, seamlessYAxis } = state.generation;
|
||||
const { boundingBoxScaleMethod } = state.canvas;
|
||||
const { refinerModel } = state.sdxl;
|
||||
@ -149,6 +151,8 @@ export const addVAEToGraph = (
|
||||
}
|
||||
|
||||
if (vae) {
|
||||
upsertMetadata(graph, { vae });
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(vae.key, isVAEModelConfig);
|
||||
const vaeMetadata: ModelMetadataField = getModelMetadataField(modelConfig);
|
||||
upsertMetadata(graph, { vae: vaeMetadata });
|
||||
}
|
||||
};
|
||||
|
@ -10,46 +10,46 @@ import { buildCanvasSDXLOutpaintGraph } from './buildCanvasSDXLOutpaintGraph';
|
||||
import { buildCanvasSDXLTextToImageGraph } from './buildCanvasSDXLTextToImageGraph';
|
||||
import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph';
|
||||
|
||||
export const buildCanvasGraph = (
|
||||
export const buildCanvasGraph = async (
|
||||
state: RootState,
|
||||
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
|
||||
canvasInitImage: ImageDTO | undefined,
|
||||
canvasMaskImage: ImageDTO | undefined
|
||||
) => {
|
||||
): Promise<NonNullableGraph> => {
|
||||
let graph: NonNullableGraph;
|
||||
|
||||
if (generationMode === 'txt2img') {
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLTextToImageGraph(state);
|
||||
graph = await buildCanvasSDXLTextToImageGraph(state);
|
||||
} else {
|
||||
graph = buildCanvasTextToImageGraph(state);
|
||||
graph = await buildCanvasTextToImageGraph(state);
|
||||
}
|
||||
} else if (generationMode === 'img2img') {
|
||||
if (!canvasInitImage) {
|
||||
throw new Error('Missing canvas init image');
|
||||
}
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
|
||||
graph = await buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
|
||||
} else {
|
||||
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
|
||||
graph = await buildCanvasImageToImageGraph(state, canvasInitImage);
|
||||
}
|
||||
} else if (generationMode === 'inpaint') {
|
||||
if (!canvasInitImage || !canvasMaskImage) {
|
||||
throw new Error('Missing canvas init and mask images');
|
||||
}
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
graph = await buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
} else {
|
||||
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
graph = await buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
}
|
||||
} else {
|
||||
if (!canvasInitImage) {
|
||||
throw new Error('Missing canvas init image');
|
||||
}
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
graph = await buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
} else {
|
||||
graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
graph = await buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,13 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import {
|
||||
type ImageDTO,
|
||||
type ImageToLatentsInvocation,
|
||||
isNonRefinerMainModelConfig,
|
||||
type NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
@ -25,12 +31,15 @@ import {
|
||||
POSITIVE_CONDITIONING,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Image to Image graph.
|
||||
*/
|
||||
export const buildCanvasImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => {
|
||||
export const buildCanvasImageToImageGraph = async (
|
||||
state: RootState,
|
||||
initialImage: ImageDTO
|
||||
): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -306,6 +315,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
|
||||
});
|
||||
}
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -316,7 +327,7 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
|
||||
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -335,17 +346,17 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
|
||||
}
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS);
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -40,11 +40,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
|
||||
/**
|
||||
* Builds the Canvas tab's Inpaint graph.
|
||||
*/
|
||||
export const buildCanvasInpaintGraph = (
|
||||
export const buildCanvasInpaintGraph = async (
|
||||
state: RootState,
|
||||
canvasInitImage: ImageDTO,
|
||||
canvasMaskImage: ImageDTO
|
||||
): NonNullableGraph => {
|
||||
): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -414,17 +414,17 @@ export const buildCanvasInpaintGraph = (
|
||||
}
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
// must add before watermarker!
|
||||
|
@ -44,11 +44,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
|
||||
/**
|
||||
* Builds the Canvas tab's Outpaint graph.
|
||||
*/
|
||||
export const buildCanvasOutpaintGraph = (
|
||||
export const buildCanvasOutpaintGraph = async (
|
||||
state: RootState,
|
||||
canvasInitImage: ImageDTO,
|
||||
canvasMaskImage?: ImageDTO
|
||||
): NonNullableGraph => {
|
||||
): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -545,18 +545,18 @@ export const buildCanvasOutpaintGraph = (
|
||||
}
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -1,6 +1,12 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import {
|
||||
type ImageDTO,
|
||||
type ImageToLatentsInvocation,
|
||||
isNonRefinerMainModelConfig,
|
||||
type NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
@ -26,12 +32,15 @@ import {
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Image to Image graph.
|
||||
*/
|
||||
export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => {
|
||||
export const buildCanvasSDXLImageToImageGraph = async (
|
||||
state: RootState,
|
||||
initialImage: ImageDTO
|
||||
): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -307,6 +316,8 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
|
||||
});
|
||||
}
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -317,7 +328,7 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
|
||||
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -338,24 +349,24 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
}
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -41,11 +41,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu
|
||||
/**
|
||||
* Builds the Canvas tab's Inpaint graph.
|
||||
*/
|
||||
export const buildCanvasSDXLInpaintGraph = (
|
||||
export const buildCanvasSDXLInpaintGraph = async (
|
||||
state: RootState,
|
||||
canvasInitImage: ImageDTO,
|
||||
canvasMaskImage: ImageDTO
|
||||
): NonNullableGraph => {
|
||||
): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -426,24 +426,31 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage, canvasMaskImage);
|
||||
await addSDXLRefinerToGraph(
|
||||
state,
|
||||
graph,
|
||||
SDXL_DENOISE_LATENTS,
|
||||
modelLoaderNodeId,
|
||||
canvasInitImage,
|
||||
canvasMaskImage
|
||||
);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
}
|
||||
}
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -45,11 +45,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu
|
||||
/**
|
||||
* Builds the Canvas tab's Outpaint graph.
|
||||
*/
|
||||
export const buildCanvasSDXLOutpaintGraph = (
|
||||
export const buildCanvasSDXLOutpaintGraph = async (
|
||||
state: RootState,
|
||||
canvasInitImage: ImageDTO,
|
||||
canvasMaskImage?: ImageDTO
|
||||
): NonNullableGraph => {
|
||||
): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -555,25 +555,25 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
|
||||
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
}
|
||||
}
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
@ -24,12 +25,12 @@ import {
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Text to Image graph.
|
||||
*/
|
||||
export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
|
||||
export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -272,6 +273,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
});
|
||||
}
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -284,7 +287,7 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
negative_prompt: negativePrompt,
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -301,24 +304,24 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
}
|
||||
}
|
||||
|
||||
// add LoRA support
|
||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -1,7 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
@ -23,12 +24,12 @@ import {
|
||||
POSITIVE_CONDITIONING,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Text to Image graph.
|
||||
*/
|
||||
export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph => {
|
||||
export const buildCanvasTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -262,6 +263,8 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
});
|
||||
}
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -272,7 +275,7 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -289,17 +292,17 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -1,7 +1,13 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import {
|
||||
type ImageResizeInvocation,
|
||||
type ImageToLatentsInvocation,
|
||||
isNonRefinerMainModelConfig,
|
||||
type NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
@ -24,12 +30,12 @@ import {
|
||||
RESIZE,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Image to Image tab graph.
|
||||
*/
|
||||
export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph => {
|
||||
export const buildLinearImageToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -307,6 +313,8 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
|
||||
});
|
||||
}
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -317,7 +325,7 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -336,17 +344,17 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -1,6 +1,12 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import {
|
||||
type ImageResizeInvocation,
|
||||
type ImageToLatentsInvocation,
|
||||
isNonRefinerMainModelConfig,
|
||||
type NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
@ -25,12 +31,12 @@ import {
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Image to Image tab graph.
|
||||
*/
|
||||
export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableGraph => {
|
||||
export const buildLinearSDXLImageToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -318,6 +324,8 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
|
||||
});
|
||||
}
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -328,7 +336,7 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -349,25 +357,25 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
}
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// Add LoRA Support
|
||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
@ -23,9 +24,9 @@ import {
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
|
||||
export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -221,6 +222,8 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
],
|
||||
};
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -231,7 +234,7 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -250,25 +253,25 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
}
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -1,7 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addHrfToGraph } from './addHrfToGraph';
|
||||
@ -23,9 +24,9 @@ import {
|
||||
SEAMLESS,
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
} from './constants';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph => {
|
||||
export const buildLinearTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -212,6 +213,8 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
],
|
||||
};
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -222,7 +225,7 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -239,18 +242,18 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// High resolution fix.
|
||||
if (state.hrf.hrfEnabled) {
|
||||
|
@ -1,5 +1,5 @@
|
||||
import type { JSONObject } from 'common/types';
|
||||
import type { CoreMetadataInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import type { AnyModelConfig, CoreMetadataInvocation, ModelMetadataField, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { METADATA } from './constants';
|
||||
|
||||
@ -71,3 +71,11 @@ export const setMetadataReceivingNode = (graph: NonNullableGraph, nodeId: string
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const getModelMetadataField = ({ key, hash, name, base, type }: AnyModelConfig): ModelMetadataField => ({
|
||||
key,
|
||||
hash,
|
||||
name,
|
||||
base,
|
||||
type,
|
||||
});
|
||||
|
@ -38,6 +38,7 @@ export const ParamNegativePrompt = memo(() => {
|
||||
onKeyDown={onKeyDown}
|
||||
fontSize="sm"
|
||||
variant="darkFilled"
|
||||
paddingRight={30}
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
|
@ -54,6 +54,7 @@ export const ParamPositivePrompt = memo(() => {
|
||||
minH={28}
|
||||
onKeyDown={onKeyDown}
|
||||
variant="darkFilled"
|
||||
paddingRight={30}
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
|
@ -41,6 +41,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
|
||||
onKeyDown={onKeyDown}
|
||||
fontSize="sm"
|
||||
variant="darkFilled"
|
||||
paddingRight={30}
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
|
@ -38,6 +38,7 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
|
||||
onKeyDown={onKeyDown}
|
||||
fontSize="sm"
|
||||
variant="darkFilled"
|
||||
paddingRight={30}
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
|
@ -109,7 +109,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
}),
|
||||
clearIntermediates: build.mutation<number, void>({
|
||||
query: () => ({ url: buildImagesUrl('intermediates'), method: 'DELETE' }),
|
||||
invalidatesTags: ['IntermediatesCount'],
|
||||
invalidatesTags: ['IntermediatesCount', 'InvocationCacheStatus'],
|
||||
}),
|
||||
getImageDTO: build.query<ImageDTO, string>({
|
||||
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}`) }),
|
||||
@ -125,7 +125,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
paths['/api/v1/images/i/{image_name}/workflow']['get']['responses']['200']['content']['application/json'],
|
||||
string
|
||||
>({
|
||||
query: (image_name) => ({ url: `images/i/${image_name}/workflow` }),
|
||||
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/workflow`) }),
|
||||
providesTags: (result, error, image_name) => [{ type: 'ImageWorkflow', id: image_name }],
|
||||
keepUnusedDataFor: 86400, // 24 hours
|
||||
}),
|
||||
|
@ -23,7 +23,14 @@ export type UpdateModelArg = {
|
||||
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
||||
};
|
||||
|
||||
export type UpdateModelImageArg = {
|
||||
key: string;
|
||||
image: Blob;
|
||||
};
|
||||
|
||||
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||
type UpdateModelImageResponse =
|
||||
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
|
||||
|
||||
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
@ -33,6 +40,7 @@ type DeleteModelArg = {
|
||||
key: string;
|
||||
};
|
||||
type DeleteModelResponse = void;
|
||||
type DeleteModelImageResponse = void;
|
||||
|
||||
type ConvertMainModelResponse =
|
||||
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
|
||||
@ -144,6 +152,18 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
|
||||
query: ({ key, image }) => {
|
||||
const formData = new FormData();
|
||||
formData.append('image', image);
|
||||
return {
|
||||
url: buildModelsUrl(`i/${key}/image`),
|
||||
method: 'PATCH',
|
||||
body: formData,
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
||||
query: ({ source }) => {
|
||||
return {
|
||||
@ -163,6 +183,18 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
deleteModelImage: build.mutation<DeleteModelImageResponse, string>({
|
||||
query: (key) => {
|
||||
return {
|
||||
url: buildModelsUrl(`i/${key}/image`),
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
getModelImage: build.query<string, string>({
|
||||
query: (key) => buildModelsUrl(`i/${key}/image`),
|
||||
}),
|
||||
convertModel: build.mutation<ConvertMainModelResponse, string>({
|
||||
query: (key) => {
|
||||
return {
|
||||
@ -330,7 +362,9 @@ export const {
|
||||
useGetTextualInversionModelsQuery,
|
||||
useGetVaeModelsQuery,
|
||||
useDeleteModelsMutation,
|
||||
useDeleteModelImageMutation,
|
||||
useUpdateModelMutation,
|
||||
useUpdateModelImageMutation,
|
||||
useInstallModelMutation,
|
||||
useConvertModelMutation,
|
||||
useSyncModelsMutation,
|
||||
|
File diff suppressed because one or more lines are too long
@ -38,7 +38,6 @@ export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDT
|
||||
export type ModelType = S['ModelType'];
|
||||
export type SubModelType = S['SubModelType'];
|
||||
export type BaseModelType = S['BaseModelType'];
|
||||
export type ControlField = S['ControlField'];
|
||||
|
||||
// Model Configs
|
||||
|
||||
@ -120,17 +119,18 @@ export type CreateGradientMaskInvocation = S['CreateGradientMaskInvocation'];
|
||||
export type CanvasPasteBackInvocation = S['CanvasPasteBackInvocation'];
|
||||
export type NoiseInvocation = S['NoiseInvocation'];
|
||||
export type DenoiseLatentsInvocation = S['DenoiseLatentsInvocation'];
|
||||
export type SDXLLoraLoaderInvocation = S['SDXLLoraLoaderInvocation'];
|
||||
export type SDXLLoRALoaderInvocation = S['SDXLLoRALoaderInvocation'];
|
||||
export type ImageToLatentsInvocation = S['ImageToLatentsInvocation'];
|
||||
export type LatentsToImageInvocation = S['LatentsToImageInvocation'];
|
||||
export type LoraLoaderInvocation = S['LoraLoaderInvocation'];
|
||||
export type LoRALoaderInvocation = S['LoRALoaderInvocation'];
|
||||
export type ESRGANInvocation = S['ESRGANInvocation'];
|
||||
export type ImageNSFWBlurInvocation = S['ImageNSFWBlurInvocation'];
|
||||
export type ImageWatermarkInvocation = S['ImageWatermarkInvocation'];
|
||||
export type SeamlessModeInvocation = S['SeamlessModeInvocation'];
|
||||
export type CoreMetadataInvocation = S['CoreMetadataInvocation'];
|
||||
export type IPAdapterMetadataField = S['IPAdapterMetadataField'];
|
||||
export type T2IAdapterField = S['T2IAdapterField'];
|
||||
|
||||
// Metadata fields
|
||||
export type ModelMetadataField = S['ModelMetadataField'];
|
||||
|
||||
// ControlNet Nodes
|
||||
export type ControlNetInvocation = S['ControlNetInvocation'];
|
||||
|
@ -33,19 +33,15 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.latent import SchedulerOutput
|
||||
from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOutput, MetadataOutput
|
||||
from invokeai.app.invocations.model import (
|
||||
ClipField,
|
||||
CLIPField,
|
||||
CLIPOutput,
|
||||
LoraInfo,
|
||||
LoraLoaderOutput,
|
||||
LoRAModelField,
|
||||
MainModelField,
|
||||
ModelInfo,
|
||||
LoRALoaderOutput,
|
||||
ModelField,
|
||||
ModelLoaderOutput,
|
||||
SDXLLoraLoaderOutput,
|
||||
SDXLLoRALoaderOutput,
|
||||
UNetField,
|
||||
UNetOutput,
|
||||
VaeField,
|
||||
VAEModelField,
|
||||
VAEField,
|
||||
VAEOutput,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import (
|
||||
@ -73,8 +69,8 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.model_management.model_manager import LoadedModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
@ -118,20 +114,16 @@ __all__ = [
|
||||
"MetadataItemOutput",
|
||||
"MetadataOutput",
|
||||
# invokeai.app.invocations.model
|
||||
"ModelInfo",
|
||||
"LoraInfo",
|
||||
"ModelField",
|
||||
"UNetField",
|
||||
"ClipField",
|
||||
"VaeField",
|
||||
"MainModelField",
|
||||
"LoRAModelField",
|
||||
"VAEModelField",
|
||||
"CLIPField",
|
||||
"VAEField",
|
||||
"UNetOutput",
|
||||
"VAEOutput",
|
||||
"CLIPOutput",
|
||||
"ModelLoaderOutput",
|
||||
"LoraLoaderOutput",
|
||||
"SDXLLoraLoaderOutput",
|
||||
"LoRALoaderOutput",
|
||||
"SDXLLoRALoaderOutput",
|
||||
# invokeai.app.invocations.primitives
|
||||
"BooleanCollectionOutput",
|
||||
"BooleanOutput",
|
||||
@ -166,7 +158,7 @@ __all__ = [
|
||||
# invokeai.app.services.config.config_default
|
||||
"InvokeAIAppConfig",
|
||||
# invokeai.backend.model_management.model_manager
|
||||
"LoadedModelInfo",
|
||||
"LoadedModel",
|
||||
# invokeai.backend.model_management.models.base
|
||||
"BaseModelType",
|
||||
"ModelType",
|
||||
|
17
scripts/generate_openapi_schema.py
Normal file
17
scripts/generate_openapi_schema.py
Normal file
@ -0,0 +1,17 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def main():
|
||||
# Change working directory to the repo root
|
||||
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from invokeai.app.api_app import custom_openapi
|
||||
|
||||
schema = custom_openapi()
|
||||
json.dump(schema, sys.stdout, indent=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -14,17 +14,13 @@ def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Pat
|
||||
matches = store.search_by_attr(model_name="test_embedding")
|
||||
assert len(matches) == 0
|
||||
key = mm2_model_manager.install.register_path(embedding_file)
|
||||
loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key))
|
||||
loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
|
||||
assert loaded_model is not None
|
||||
assert loaded_model.config.key == key
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
loaded_model_2 = mm2_model_manager.load_model_by_key(key)
|
||||
assert loaded_model.config.key == loaded_model_2.config.key
|
||||
|
||||
loaded_model_3 = mm2_model_manager.load_model_by_attr(
|
||||
model_name=loaded_model.config.name,
|
||||
model_type=loaded_model.config.type,
|
||||
base_model=loaded_model.config.base,
|
||||
)
|
||||
assert loaded_model.config.key == loaded_model_3.config.key
|
||||
config = mm2_model_manager.store.get_model(key)
|
||||
loaded_model_2 = mm2_model_manager.load.load_model(config)
|
||||
|
||||
assert loaded_model.config.key == loaded_model_2.config.key
|
||||
|
@ -44,6 +44,7 @@ def mock_services() -> InvocationServices:
|
||||
images=ImageService(),
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=0),
|
||||
logger=logging, # type: ignore
|
||||
model_images=None, # type: ignore
|
||||
model_manager=None, # type: ignore
|
||||
download_queue=None, # type: ignore
|
||||
names=None, # type: ignore
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user