mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

Compare commits: speed-up-h ... separate-g
5 Commits:
cd3f5f30dc
71ee28ac12
46c904d08a
7d5a88b69d
afa4df1991
Makefile (8 changed lines)

@@ -10,11 +10,10 @@ help:
 	@echo "ruff-unsafe  Run ruff, fixing all fixable errors and formatting"
 	@echo "mypy         Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
 	@echo "mypy-all     Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
-	@echo "test         Run the unit tests."
-	@echo "frontend-install  Install the pnpm modules needed for the front end"
+	@echo "test"	Run the unit tests.
+	@echo "frontend-install"	Install the pnpm modules needed for the front end
 	@echo "frontend-build  Build the frontend in order to run on localhost:9090"
 	@echo "frontend-dev    Run the frontend in developer mode on localhost:5173"
-	@echo "frontend-typegen  Generate types for the frontend from the OpenAPI schema"
 	@echo "installer-zip  Build the installer .zip file for the current version"
 	@echo "tag-release    Tag the GitHub repository with the current version (use at release time only!)"
 
@@ -54,9 +53,6 @@ frontend-build:
 frontend-dev:
 	cd invokeai/frontend/web && pnpm dev
 
-frontend-typegen:
-	cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
-
 # Installer zip file
 installer-zip:
 	cd installer && ./create_installer.sh
@@ -25,7 +25,6 @@ from ..services.invocation_cache.invocation_cache_memory import MemoryInvocation
 from ..services.invocation_services import InvocationServices
 from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
 from ..services.invoker import Invoker
-from ..services.model_images.model_images_default import ModelImageFileStorageDisk
 from ..services.model_manager.model_manager_default import ModelManagerService
 from ..services.model_records import ModelRecordServiceSQL
 from ..services.names.names_default import SimpleNameService
@@ -72,8 +71,6 @@ class ApiDependencies:
 
         image_files = DiskImageFileStorage(f"{output_folder}/images")
 
-        model_images_folder = config.models_path
-
         db = init_db(config=config, logger=logger, image_files=image_files)
 
         configuration = config
@@ -95,7 +92,6 @@ class ApiDependencies:
             ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
         )
         download_queue_service = DownloadQueueService(event_bus=events)
-        model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
         model_manager = ModelManagerService.build_model_manager(
             app_config=configuration,
             model_record_service=ModelRecordServiceSQL(db=db),
@@ -122,7 +118,6 @@ class ApiDependencies:
             images=images,
             invocation_cache=invocation_cache,
             logger=logger,
-            model_images=model_images_service,
             model_manager=model_manager,
             download_queue=download_queue_service,
             names=names,
@@ -1,16 +1,12 @@
 # Copyright (c) 2023 Lincoln D. Stein
 """FastAPI route for model configuration records."""
 
-import io
 import pathlib
 import shutil
-import traceback
 from typing import Any, Dict, List, Optional
 
-from fastapi import Body, Path, Query, Response, UploadFile
-from fastapi.responses import FileResponse
+from fastapi import Body, Path, Query, Response
 from fastapi.routing import APIRouter
-from PIL import Image
 from pydantic import BaseModel, ConfigDict, Field
 from starlette.exceptions import HTTPException
 from typing_extensions import Annotated
@@ -35,9 +31,6 @@ from ..dependencies import ApiDependencies
 
 model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
 
-# images are immutable; set a high max-age
-IMAGE_MAX_AGE = 31536000
-
 
 class ModelsList(BaseModel):
     """Return list of configs."""
@@ -112,9 +105,6 @@ async def list_model_records(
         found_models.extend(
             record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
         )
-    for model in found_models:
-        cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
-        model.cover_image = cover_image
     return ModelsList(models=found_models)
 
 
@@ -158,8 +148,6 @@ async def get_model_record(
     record_store = ApiDependencies.invoker.services.model_manager.store
     try:
         config: AnyModelConfig = record_store.get_model(key)
-        cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
-        config.cover_image = cover_image
        return config
     except UnknownModelException as e:
        raise HTTPException(status_code=404, detail=str(e))
@@ -278,75 +266,6 @@ async def update_model_record(
     return model_response
 
 
-@model_manager_router.get(
-    "/i/{key}/image",
-    operation_id="get_model_image",
-    responses={
-        200: {
-            "description": "The model image was fetched successfully",
-        },
-        400: {"description": "Bad request"},
-        404: {"description": "The model image could not be found"},
-    },
-    status_code=200,
-)
-async def get_model_image(
-    key: str = Path(description="The name of model image file to get"),
-) -> FileResponse:
-    """Gets an image file that previews the model"""
-
-    try:
-        path = ApiDependencies.invoker.services.model_images.get_path(key)
-
-        response = FileResponse(
-            path,
-            media_type="image/png",
-            filename=key + ".png",
-            content_disposition_type="inline",
-        )
-        response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
-        return response
-    except Exception:
-        raise HTTPException(status_code=404)
-
-
-@model_manager_router.patch(
-    "/i/{key}/image",
-    operation_id="update_model_image",
-    responses={
-        200: {
-            "description": "The model image was updated successfully",
-        },
-        400: {"description": "Bad request"},
-    },
-    status_code=200,
-)
-async def update_model_image(
-    key: Annotated[str, Path(description="Unique key of model")],
-    image: UploadFile,
-) -> None:
-    if not image.content_type or not image.content_type.startswith("image"):
-        raise HTTPException(status_code=415, detail="Not an image")
-
-    contents = await image.read()
-    try:
-        pil_image = Image.open(io.BytesIO(contents))
-
-    except Exception:
-        ApiDependencies.invoker.services.logger.error(traceback.format_exc())
-        raise HTTPException(status_code=415, detail="Failed to read image")
-
-    logger = ApiDependencies.invoker.services.logger
-    model_images = ApiDependencies.invoker.services.model_images
-    try:
-        model_images.save(pil_image, key)
-        logger.info(f"Updated image for model: {key}")
-    except ValueError as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=409, detail=str(e))
-    return
-
-
 @model_manager_router.delete(
     "/i/{key}",
     operation_id="delete_model",
@@ -377,29 +296,6 @@ async def delete_model(
         raise HTTPException(status_code=404, detail=str(e))
 
 
-@model_manager_router.delete(
-    "/i/{key}/image",
-    operation_id="delete_model_image",
-    responses={
-        204: {"description": "Model image deleted successfully"},
-        404: {"description": "Model image not found"},
-    },
-    status_code=204,
-)
-async def delete_model_image(
-    key: str = Path(description="Unique key of model image to remove from model_images directory."),
-) -> None:
-    logger = ApiDependencies.invoker.services.logger
-    model_images = ApiDependencies.invoker.services.model_images
-    try:
-        model_images.delete(key)
-        logger.info(f"Deleted model image: {key}")
-        return
-    except UnknownModelException as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=404, detail=str(e))
-
-
 # @model_manager_router.post(
 #     "/i/",
 #     operation_id="add_model_record",
@@ -643,7 +539,7 @@ async def convert_model(
         raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
 
     # loading the model will convert it into a cached diffusers file
-    model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
+    model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
 
     # Get the path of the converted model from the loader
     cache_path = loader.convert_cache.cache_path(key)
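
For context, a minimal client sketch of how the image-preview routes deleted above were called. This is an illustration only: it assumes a local server, the requests package, and that model_manager_router is mounted under /api; the base URL and model key are hypothetical.

import requests

BASE = "http://localhost:9090/api/v2/models"  # assumed mount point for model_manager_router
key = "some-model-key"  # hypothetical model record key

# GET /i/{key}/image returned the model's PNG preview (a FileResponse with a long max-age).
resp = requests.get(f"{BASE}/i/{key}/image")
if resp.ok:
    with open(f"{key}.png", "wb") as f:
        f.write(resp.content)

# PATCH /i/{key}/image uploaded a replacement preview as multipart form data.
with open("preview.png", "rb") as f:
    requests.patch(f"{BASE}/i/{key}/image", files={"image": ("preview.png", f, "image/png")})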
@@ -2,9 +2,12 @@
 # which are imported/used before parse_args() is called will get the default config values instead of the
 # values from the command line or config file.
 import sys
+from contextlib import asynccontextmanager
 
+from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
 from invokeai.version.invokeai_version import __version__
 
+from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
 from .services.config import InvokeAIAppConfig
 
 app_config = InvokeAIAppConfig.get_config()
@@ -17,7 +20,6 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
     import asyncio
     import mimetypes
     import socket
-    from contextlib import asynccontextmanager
     from inspect import signature
     from pathlib import Path
     from typing import Any
@@ -38,7 +40,6 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
     # noinspection PyUnresolvedReferences
     import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)
     import invokeai.frontend.web as web_dir
-    from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
 
     from ..backend.util.logging import InvokeAILogger
     from .api.dependencies import ApiDependencies
@@ -58,7 +59,6 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
         BaseInvocation,
         UIConfigBase,
     )
-    from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
 
 if is_mps_available():
     import invokeai.backend.util.mps_fixes  # noqa: F401 (monkeypatching on import)
@@ -20,7 +20,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 from invokeai.backend.util.devices import torch_dtype
 
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
-from .model import CLIPField
+from .model import ClipField
 
 # unconditioned: Optional[torch.Tensor]
 
@@ -46,7 +46,7 @@ class CompelInvocation(BaseInvocation):
         description=FieldDescriptions.compel_prompt,
         ui_component=UIComponent.Textarea,
     )
-    clip: CLIPField = InputField(
+    clip: ClipField = InputField(
         title="CLIP",
         description=FieldDescriptions.clip,
         input=Input.Connection,
@@ -54,16 +54,16 @@ class CompelInvocation(BaseInvocation):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.models.load(self.clip.tokenizer)
+        tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
         tokenizer_model = tokenizer_info.model
         assert isinstance(tokenizer_model, CLIPTokenizer)
-        text_encoder_info = context.models.load(self.clip.text_encoder)
+        text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
         text_encoder_model = text_encoder_info.model
         assert isinstance(text_encoder_model, CLIPTextModel)
 
        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
            for lora in self.clip.loras:
-                lora_info = context.models.load(lora.lora)
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                assert isinstance(lora_info.model, LoRAModelRaw)
                yield (lora_info.model, lora.weight)
                del lora_info
@@ -127,16 +127,16 @@ class SDXLPromptInvocationBase:
     def run_clip_compel(
         self,
         context: InvocationContext,
-        clip_field: CLIPField,
+        clip_field: ClipField,
         prompt: str,
         get_pooled: bool,
         lora_prefix: str,
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
-        tokenizer_info = context.models.load(clip_field.tokenizer)
+        tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
         tokenizer_model = tokenizer_info.model
         assert isinstance(tokenizer_model, CLIPTokenizer)
-        text_encoder_info = context.models.load(clip_field.text_encoder)
+        text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
         text_encoder_model = text_encoder_info.model
         assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
 
@@ -163,7 +163,7 @@ class SDXLPromptInvocationBase:
 
        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
            for lora in clip_field.loras:
-                lora_info = context.models.load(lora.lora)
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                lora_model = lora_info.model
                assert isinstance(lora_model, LoRAModelRaw)
                yield (lora_model, lora.weight)
@@ -253,8 +253,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     crop_left: int = InputField(default=0, description="")
     target_width: int = InputField(default=1024, description="")
     target_height: int = InputField(default=1024, description="")
-    clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
-    clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
+    clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
+    clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
@@ -340,7 +340,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
     crop_top: int = InputField(default=0, description="")
     crop_left: int = InputField(default=0, description="")
     aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
-    clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
+    clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
@@ -370,10 +370,10 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
 
 
 @invocation_output("clip_skip_output")
-class CLIPSkipInvocationOutput(BaseInvocationOutput):
-    """CLIP skip node output"""
+class ClipSkipInvocationOutput(BaseInvocationOutput):
+    """Clip skip node output"""
 
-    clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
+    clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
 
 
 @invocation(
@@ -383,15 +383,15 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput):
     category="conditioning",
     version="1.0.0",
 )
-class CLIPSkipInvocation(BaseInvocation):
+class ClipSkipInvocation(BaseInvocation):
     """Skip layers in clip text_encoder model."""
 
-    clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
+    clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
     skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
 
-    def invoke(self, context: InvocationContext) -> CLIPSkipInvocationOutput:
+    def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
         self.clip.skipped_layers += self.skipped_layers
-        return CLIPSkipInvocationOutput(
+        return ClipSkipInvocationOutput(
             clip=self.clip,
         )
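
The recurring change in this file (and in the latents and model loader diffs below) is the model-loading call: context.models.load(self.clip.tokenizer) on the '-' side becomes context.models.load(**self.clip.tokenizer.model_dump()) on the '+' side. A hedged sketch of why both spellings reach the same loader, using simplified stand-ins (class, function, and key are hypothetical):

from typing import Optional
from pydantic import BaseModel

class ModelInfo(BaseModel):  # simplified stand-in for the field classes in model.py
    key: str
    submodel_type: Optional[str] = None

def load(key: str, submodel_type: Optional[str] = None) -> str:
    # stand-in for the keyword-argument form of context.models.load
    return f"loaded {key}/{submodel_type}"

tokenizer = ModelInfo(key="abc123", submodel_type="tokenizer")

# '+' side: dump the pydantic model to a dict and unpack it into keyword arguments.
print(load(**tokenizer.model_dump()))  # loaded abc123/tokenizer
# The '-' side instead passes the field object itself and lets the service unpack it.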
@@ -34,7 +34,6 @@ from invokeai.app.invocations.fields import (
     WithBoard,
     WithMetadata,
 )
-from invokeai.app.invocations.model import ModelField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -52,9 +51,15 @@ CONTROLNET_RESIZE_VALUES = Literal[
 ]
 
 
+class ControlNetModelField(BaseModel):
+    """ControlNet model field"""
+
+    key: str = Field(description="Model config record key for the ControlNet model")
+
+
 class ControlField(BaseModel):
     image: ImageField = Field(description="The control image")
-    control_model: ModelField = Field(description="The ControlNet model to use")
+    control_model: ControlNetModelField = Field(description="The ControlNet model to use")
     control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
     begin_step_percent: float = Field(
         default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@@ -90,7 +95,7 @@ class ControlNetInvocation(BaseInvocation):
     """Collects ControlNet info to pass to other nodes"""
 
     image: ImageField = InputField(description="The control image")
-    control_model: ModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
+    control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
     control_weight: Union[float, List[float]] = InputField(
         default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
     )
@@ -228,7 +228,7 @@ class ConditioningField(BaseModel):
 # endregion
 
 
-class MetadataField(RootModel[dict[str, Any]]):
+class MetadataField(RootModel):
     """
     Pydantic model for metadata with custom root of type dict[str, Any].
     Metadata is stored without a strict schema.
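
The change above drops the dict[str, Any] parameter from RootModel. A small sketch of the practical difference under pydantic v2 (class names hypothetical):

from typing import Any
from pydantic import RootModel

class TypedMetadata(RootModel[dict[str, Any]]):
    """Root is validated as dict[str, Any]; non-dict input is rejected."""

class UntypedMetadata(RootModel):
    """An unparameterized root defaults to Any, so any value validates."""

print(TypedMetadata.model_validate({"seed": 123}).root)  # {'seed': 123}
print(UntypedMetadata.model_validate("anything").root)   # anything
# TypedMetadata.model_validate("anything") would raise a ValidationError.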
@@ -11,17 +11,25 @@ from invokeai.app.invocations.baseinvocation import (
     invocation_output,
 )
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
-from invokeai.app.invocations.model import ModelField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType
+from invokeai.backend.model_manager.config import BaseModelType, ModelType
 
 
+# LS: Consider moving these two classes into model.py
+class IPAdapterModelField(BaseModel):
+    key: str = Field(description="Key to the IP-Adapter model")
+
+
+class CLIPVisionModelField(BaseModel):
+    key: str = Field(description="Key to the CLIP Vision image encoder model")
+
+
 class IPAdapterField(BaseModel):
     image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
-    ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.")
-    image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.")
+    ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
+    image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
     weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
     begin_step_percent: float = Field(
         default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
@@ -54,7 +62,7 @@ class IPAdapterInvocation(BaseInvocation):
 
     # Inputs
     image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
-    ip_adapter_model: ModelField = InputField(
+    ip_adapter_model: IPAdapterModelField = InputField(
         description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
     )
 
@@ -82,18 +90,18 @@ class IPAdapterInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
         ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
-        assert isinstance(ip_adapter_info, IPAdapterConfig)
         image_encoder_model_id = ip_adapter_info.image_encoder_model_id
         image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
         image_encoder_models = context.models.search_by_attrs(
             name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
         )
         assert len(image_encoder_models) == 1
+        image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
         return IPAdapterOutput(
             ip_adapter=IPAdapterField(
                 image=self.image,
                 ip_adapter_model=self.ip_adapter_model,
-                image_encoder_model=ModelField(key=image_encoder_models[0].key),
+                image_encoder_model=image_encoder_model,
                 weight=self.weight,
                 begin_step_percent=self.begin_step_percent,
                 end_step_percent=self.end_step_percent,
@@ -26,7 +26,6 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
 from PIL import Image, ImageFilter
 from pydantic import field_validator
 from torchvision.transforms.functional import resize as tv_resize
-from transformers import CLIPVisionModelWithProjection
 
 from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
 from invokeai.app.invocations.fields import (
@@ -76,7 +75,7 @@ from .baseinvocation import (
     invocation_output,
 )
 from .controlnet_image_processors import ControlField
-from .model import ModelField, UNetField, VAEField
+from .model import ModelInfo, UNetField, VaeField
 
 if choose_torch_device() == torch.device("mps"):
     from torch import mps
@@ -119,7 +118,7 @@ class SchedulerInvocation(BaseInvocation):
 class CreateDenoiseMaskInvocation(BaseInvocation):
     """Creates mask for denoising model run."""
 
-    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
+    vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
     image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
     mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
@@ -154,7 +153,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
         )
 
         if image_tensor is not None:
-            vae_info = context.models.load(self.vae.vae)
+            vae_info = context.models.load(**self.vae.vae.model_dump())
 
             img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
             masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
@@ -245,12 +244,12 @@ class CreateGradientMaskInvocation(BaseInvocation):
 
 def get_scheduler(
     context: InvocationContext,
-    scheduler_info: ModelField,
+    scheduler_info: ModelInfo,
     scheduler_name: str,
     seed: int,
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
-    orig_scheduler_info = context.models.load(scheduler_info)
+    orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
     with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
 
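
get_scheduler above resolves a scheduler name to a (class, extra-config) pair and silently falls back to DDIM for unknown names. A standalone sketch of that lookup pattern (map contents are hypothetical stand-ins for the real diffusers classes):

SCHEDULER_MAP = {
    "ddim": ("DDIMScheduler", {}),
    "euler": ("EulerDiscreteScheduler", {}),
}

def resolve(scheduler_name: str):
    # dict.get with a default yields the DDIM pair for unrecognized names
    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
    return scheduler_class, scheduler_extra_config

print(resolve("euler"))    # ('EulerDiscreteScheduler', {})
print(resolve("unknown"))  # ('DDIMScheduler', {})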
@@ -462,7 +461,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         # and if weight is None, populate with default 1.0?
         controlnet_data = []
         for control_info in control_list:
-            control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
+            control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
 
             # control_models.append(control_model)
             control_image_field = control_info.image
@@ -524,10 +523,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
             conditioning_data.ip_adapter_conditioning = []
             for single_ip_adapter in ip_adapter:
                 ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
-                    context.models.load(single_ip_adapter.ip_adapter_model)
+                    context.models.load(key=single_ip_adapter.ip_adapter_model.key)
                 )
 
-                image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
+                image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
 
                 # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
                 single_ipa_image_fields = single_ip_adapter.image
                 if not isinstance(single_ipa_image_fields, list):
@@ -538,7 +538,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
                 # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
                 with image_encoder_model_info as image_encoder_model:
-                    assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
                     # Get image embeddings from CLIP and ImageProjModel.
                     image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
                         single_ipa_images, image_encoder_model
@@ -578,8 +577,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         t2i_adapter_data = []
         for t2i_adapter_field in t2i_adapter:
-            t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
-            t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
+            t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
+            t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
             image = context.images.get_pil(t2i_adapter_field.image.image_name)
 
             # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -732,13 +731,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.unet.loras:
-                lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                 yield (lora_info.model, lora.weight)
                 del lora_info
             return
 
-        unet_info = context.models.load(self.unet.unet)
+        unet_info = context.models.load(**self.unet.unet.model_dump())
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
@@ -832,7 +830,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         description=FieldDescriptions.latents,
         input=Input.Connection,
     )
-    vae: VAEField = InputField(
+    vae: VaeField = InputField(
         description=FieldDescriptions.vae,
         input=Input.Connection,
     )
@@ -843,8 +841,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)
 
-        vae_info = context.models.load(self.vae.vae)
-        assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
+        vae_info = context.models.load(**self.vae.vae.model_dump())
 
         with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
+            assert isinstance(vae, torch.nn.Module)
             latents = latents.to(vae.device)
@@ -1010,7 +1008,7 @@ class ImageToLatentsInvocation(BaseInvocation):
     image: ImageField = InputField(
         description="The image to encode",
     )
-    vae: VAEField = InputField(
+    vae: VaeField = InputField(
         description=FieldDescriptions.vae,
         input=Input.Connection,
     )
@@ -1066,7 +1064,7 @@ class ImageToLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         image = context.images.get_pil(self.image.image_name)
 
-        vae_info = context.models.load(self.vae.vae)
+        vae_info = context.models.load(**self.vae.vae.model_dump())
 
         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
         if image_tensor.dim() == 3:
@@ -8,10 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
     invocation,
     invocation_output,
 )
-from invokeai.app.invocations.controlnet_image_processors import (
-    CONTROLNET_MODE_VALUES,
-    CONTROLNET_RESIZE_VALUES,
-)
+from invokeai.app.invocations.controlnet_image_processors import ControlField
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     ImageField,
@@ -20,8 +17,10 @@ from invokeai.app.invocations.fields import (
     OutputField,
     UIType,
 )
+from invokeai.app.invocations.ip_adapter import IPAdapterModelField
+from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
+from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager.config import BaseModelType, ModelType
 
 from ...version import __version__
 
@@ -31,20 +30,10 @@ class MetadataItemField(BaseModel):
     value: Any = Field(description=FieldDescriptions.metadata_item_value)
 
 
-class ModelMetadataField(BaseModel):
-    """Model Metadata Field"""
-
-    key: str
-    hash: str
-    name: str
-    base: BaseModelType
-    type: ModelType
-
-
 class LoRAMetadataField(BaseModel):
     """LoRA Metadata Field"""
 
-    model: ModelMetadataField = Field(description=FieldDescriptions.lora_model)
+    model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
     weight: float = Field(description=FieldDescriptions.lora_weight)
 
 
@@ -52,7 +41,7 @@ class IPAdapterMetadataField(BaseModel):
     """IP Adapter Field, minus the CLIP Vision Encoder model"""
 
     image: ImageField = Field(description="The IP-Adapter image prompt.")
-    ip_adapter_model: ModelMetadataField = Field(
+    ip_adapter_model: IPAdapterModelField = Field(
         description="The IP-Adapter model.",
     )
     weight: Union[float, list[float]] = Field(
@@ -62,33 +51,6 @@ class IPAdapterMetadataField(BaseModel):
     end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
 
 
-class T2IAdapterMetadataField(BaseModel):
-    image: ImageField = Field(description="The T2I-Adapter image prompt.")
-    t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.")
-    weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
-    begin_step_percent: float = Field(
-        default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
-    )
-    end_step_percent: float = Field(
-        default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
-    )
-    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
-
-
-class ControlNetMetadataField(BaseModel):
-    image: ImageField = Field(description="The control image")
-    control_model: ModelMetadataField = Field(description="The ControlNet model to use")
-    control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
-    begin_step_percent: float = Field(
-        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
-    )
-    end_step_percent: float = Field(
-        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
-    )
-    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
-    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
-
-
 @invocation_output("metadata_item_output")
 class MetadataItemOutput(BaseInvocationOutput):
     """Metadata Item Output"""
@@ -178,14 +140,14 @@ class CoreMetadataInvocation(BaseInvocation):
         default=None,
         description="The number of skipped CLIP layers",
     )
-    model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference")
-    controlnets: Optional[list[ControlNetMetadataField]] = InputField(
+    model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference")
+    controlnets: Optional[list[ControlField]] = InputField(
         default=None, description="The ControlNets used for inference"
     )
     ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
         default=None, description="The IP Adapters used for inference"
     )
-    t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField(
+    t2iAdapters: Optional[list[T2IAdapterField]] = InputField(
         default=None, description="The IP Adapters used for inference"
     )
     loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
@@ -197,7 +159,7 @@ class CoreMetadataInvocation(BaseInvocation):
         default=None,
         description="The name of the initial image",
     )
-    vae: Optional[ModelMetadataField] = InputField(
+    vae: Optional[VAEModelField] = InputField(
         default=None,
         description="The VAE used for decoding, if the main model's default was not used",
     )
@@ -228,7 +190,7 @@ class CoreMetadataInvocation(BaseInvocation):
     )
 
     # SDXL Refiner
-    refiner_model: Optional[ModelMetadataField] = InputField(
+    refiner_model: Optional[MainModelField] = InputField(
         default=None,
         description="The SDXL Refiner model used",
     )
@@ -260,9 +222,10 @@ class CoreMetadataInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> MetadataOutput:
         """Collects and outputs a CoreMetadata object"""
 
-        as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
-        as_dict["app_version"] = __version__
-
-        return MetadataOutput(metadata=MetadataField.model_validate(as_dict))
+        return MetadataOutput(
+            metadata=MetadataField.model_validate(
+                self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
+            )
+        )
 
     model_config = ConfigDict(extra="allow")
@@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
-from invokeai.backend.model_manager.config import SubModelType
 
+from ...backend.model_manager import SubModelType
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -16,33 +16,33 @@ from .baseinvocation import (
 )
 
 
-class ModelField(BaseModel):
-    key: str = Field(description="Key of the model")
-    submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
+class ModelInfo(BaseModel):
+    key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
+    submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
 
 
-class LoRAField(BaseModel):
-    lora: ModelField = Field(description="Info to load lora model")
-    weight: float = Field(description="Weight to apply to lora model")
+class LoraInfo(ModelInfo):
+    weight: float = Field(description="Lora's weight which to use when apply to model")
 
 
 class UNetField(BaseModel):
-    unet: ModelField = Field(description="Info to load unet submodel")
-    scheduler: ModelField = Field(description="Info to load scheduler submodel")
-    loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
+    unet: ModelInfo = Field(description="Info to load unet submodel")
+    scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
+    loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
     seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
     freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
 
 
-class CLIPField(BaseModel):
-    tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
-    text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
+class ClipField(BaseModel):
+    tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
+    text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
     skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
-    loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
+    loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
 
 
-class VAEField(BaseModel):
-    vae: ModelField = Field(description="Info to load vae submodel")
+class VaeField(BaseModel):
+    # TODO: better naming?
+    vae: ModelInfo = Field(description="Info to load vae submodel")
     seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
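
To make the reverted schema concrete, a small self-contained sketch that mirrors the ModelInfo/LoraInfo/UNetField shapes above with simplified types (key values are hypothetical, and plain strings stand in for the real SubModelType enum):

from typing import List, Optional
from pydantic import BaseModel, Field

class ModelInfo(BaseModel):  # simplified stand-in for the class above
    key: str
    submodel_type: Optional[str] = None

class LoraInfo(ModelInfo):  # inherits key/submodel_type, adds a weight
    weight: float

class UNetField(BaseModel):
    unet: ModelInfo
    scheduler: ModelInfo
    loras: List[LoraInfo] = Field(default_factory=list)

key = "1a2b3c"  # hypothetical record key
unet_field = UNetField(
    unet=ModelInfo(key=key, submodel_type="unet"),
    scheduler=ModelInfo(key=key, submodel_type="scheduler"),
    loras=[LoraInfo(key="lora-key", weight=0.75)],  # hypothetical LoRA
)
# Each ModelInfo later round-trips through context.models.load(**info.model_dump()).
print(unet_field.unet.model_dump())  # {'key': '1a2b3c', 'submodel_type': 'unet'}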
@@ -57,14 +57,14 @@ class UNetOutput(BaseInvocationOutput):
 class VAEOutput(BaseInvocationOutput):
     """Base class for invocations that output a VAE field"""
 
-    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
+    vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
 
 
 @invocation_output("clip_output")
 class CLIPOutput(BaseInvocationOutput):
     """Base class for invocations that output a CLIP field"""
 
-    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
+    clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
 
 
 @invocation_output("model_loader_output")
@@ -74,6 +74,18 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
     pass
 
 
+class MainModelField(BaseModel):
+    """Main model field"""
+
+    key: str = Field(description="Model key")
+
+
+class LoRAModelField(BaseModel):
+    """LoRA model field"""
+
+    key: str = Field(description="LoRA model key")
+
+
 @invocation(
     "main_model_loader",
     title="Main Model",
@@ -84,40 +96,62 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
 class MainModelLoaderInvocation(BaseInvocation):
     """Loads a main model, outputting its submodels."""
 
-    model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
+    model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
     # TODO: precision?
 
     def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
-        # TODO: not found exceptions
-        if not context.models.exists(self.model.key):
-            raise Exception(f"Unknown model {self.model.key}")
+        key = self.model.key
 
-        unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
-        scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
-        tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
-        text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
-        vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
+        # TODO: not found exceptions
+        if not context.models.exists(key):
+            raise Exception(f"Unknown model {key}")
 
         return ModelLoaderOutput(
-            unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
-            clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
-            vae=VAEField(vae=vae),
+            unet=UNetField(
+                unet=ModelInfo(
+                    key=key,
+                    submodel_type=SubModelType.UNet,
+                ),
+                scheduler=ModelInfo(
+                    key=key,
+                    submodel_type=SubModelType.Scheduler,
+                ),
+                loras=[],
+            ),
+            clip=ClipField(
+                tokenizer=ModelInfo(
+                    key=key,
+                    submodel_type=SubModelType.Tokenizer,
+                ),
+                text_encoder=ModelInfo(
+                    key=key,
+                    submodel_type=SubModelType.TextEncoder,
+                ),
+                loras=[],
+                skipped_layers=0,
+            ),
+            vae=VaeField(
+                vae=ModelInfo(
+                    key=key,
+                    submodel_type=SubModelType.VAE,
+                ),
+            ),
         )
 
 
 @invocation_output("lora_loader_output")
-class LoRALoaderOutput(BaseInvocationOutput):
+class LoraLoaderOutput(BaseInvocationOutput):
     """Model loader output"""
 
     unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
-    clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
+    clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
 
 
 @invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
-class LoRALoaderInvocation(BaseInvocation):
+class LoraLoaderInvocation(BaseInvocation):
     """Apply selected lora to unet and text_encoder."""
 
-    lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
+    lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
     weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
     unet: Optional[UNetField] = InputField(
         default=None,
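
The '-' side of the loader above derives each submodel reference with pydantic's model_copy(update=...), while the '+' side constructs every ModelInfo explicitly. A minimal sketch of why the two spellings are equivalent (class and key are simplified, hypothetical stand-ins):

from typing import Optional
from pydantic import BaseModel

class ModelInfo(BaseModel):  # simplified stand-in
    key: str
    submodel_type: Optional[str] = None

model = ModelInfo(key="main-key")  # hypothetical

# model_copy(update=...) returns a copy with the listed fields replaced...
unet_a = model.model_copy(update={"submodel_type": "unet"})
# ...which matches building the reference by hand.
unet_b = ModelInfo(key=model.key, submodel_type="unet")
assert unet_a == unet_b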
@@ -125,41 +159,46 @@ class LoRALoaderInvocation(BaseInvocation):
         input=Input.Connection,
         title="UNet",
     )
-    clip: Optional[CLIPField] = InputField(
+    clip: Optional[ClipField] = InputField(
         default=None,
         description=FieldDescriptions.clip,
         input=Input.Connection,
         title="CLIP",
     )
 
-    def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
+    def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
         if self.lora is None:
             raise Exception("No LoRA provided")
 
         lora_key = self.lora.key
 
         if not context.models.exists(lora_key):
             raise Exception(f"Unkown lora: {lora_key}!")
 
-        if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
-            raise Exception(f'LoRA "{lora_key}" already applied to unet')
+        if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
+            raise Exception(f'Lora "{lora_key}" already applied to unet')
 
-        if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
-            raise Exception(f'LoRA "{lora_key}" already applied to clip')
+        if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
+            raise Exception(f'Lora "{lora_key}" already applied to clip')
 
-        output = LoRALoaderOutput()
+        output = LoraLoaderOutput()
 
         if self.unet is not None:
-            output.unet = self.unet.model_copy(deep=True)
+            output.unet = copy.deepcopy(self.unet)
             output.unet.loras.append(
-                LoRAField(
-                    lora=self.lora,
+                LoraInfo(
+                    key=lora_key,
+                    submodel_type=None,
                     weight=self.weight,
                 )
             )
 
         if self.clip is not None:
-            output.clip = self.clip.model_copy(deep=True)
+            output.clip = copy.deepcopy(self.clip)
             output.clip.loras.append(
-                LoRAField(
-                    lora=self.lora,
+                LoraInfo(
+                    key=lora_key,
+                    submodel_type=None,
                     weight=self.weight,
                 )
             )
@@ -168,12 +207,12 @@ class LoRALoaderInvocation(BaseInvocation):
 
 
 @invocation_output("sdxl_lora_loader_output")
-class SDXLLoRALoaderOutput(BaseInvocationOutput):
+class SDXLLoraLoaderOutput(BaseInvocationOutput):
     """SDXL LoRA Loader Output"""
 
     unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
-    clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
-    clip2: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
+    clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
+    clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
 
 
 @invocation(
@@ -183,10 +222,10 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
     category="model",
     version="1.0.1",
 )
-class SDXLLoRALoaderInvocation(BaseInvocation):
+class SDXLLoraLoaderInvocation(BaseInvocation):
     """Apply selected lora to unet and text_encoder."""
 
-    lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
+    lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
     weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
     unet: Optional[UNetField] = InputField(
         default=None,
@@ -194,59 +233,65 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
         input=Input.Connection,
         title="UNet",
     )
-    clip: Optional[CLIPField] = InputField(
+    clip: Optional[ClipField] = InputField(
         default=None,
         description=FieldDescriptions.clip,
         input=Input.Connection,
         title="CLIP 1",
     )
-    clip2: Optional[CLIPField] = InputField(
+    clip2: Optional[ClipField] = InputField(
         default=None,
         description=FieldDescriptions.clip,
         input=Input.Connection,
         title="CLIP 2",
     )
 
-    def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
+    def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
         if self.lora is None:
             raise Exception("No LoRA provided")
 
         lora_key = self.lora.key
 
         if not context.models.exists(lora_key):
             raise Exception(f"Unknown lora: {lora_key}!")
 
-        if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
-            raise Exception(f'LoRA "{lora_key}" already applied to unet')
+        if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
+            raise Exception(f'Lora "{lora_key}" already applied to unet')
 
-        if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
-            raise Exception(f'LoRA "{lora_key}" already applied to clip')
+        if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
            raise Exception(f'Lora "{lora_key}" already applied to clip')
 
-        if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras):
-            raise Exception(f'LoRA "{lora_key}" already applied to clip2')
+        if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
+            raise Exception(f'Lora "{lora_key}" already applied to clip2')
 
-        output = SDXLLoRALoaderOutput()
+        output = SDXLLoraLoaderOutput()
 
         if self.unet is not None:
-            output.unet = self.unet.model_copy(deep=True)
+            output.unet = copy.deepcopy(self.unet)
             output.unet.loras.append(
-                LoRAField(
-                    lora=self.lora,
+                LoraInfo(
+                    key=lora_key,
+                    submodel_type=None,
                     weight=self.weight,
                 )
             )
 
         if self.clip is not None:
-            output.clip = self.clip.model_copy(deep=True)
+            output.clip = copy.deepcopy(self.clip)
             output.clip.loras.append(
-                LoRAField(
-                    lora=self.lora,
+                LoraInfo(
+                    key=lora_key,
+                    submodel_type=None,
                    weight=self.weight,
                )
            )
 
         if self.clip2 is not None:
-            output.clip2 = self.clip2.model_copy(deep=True)
+            output.clip2 = copy.deepcopy(self.clip2)
             output.clip2.loras.append(
-                LoRAField(
-                    lora=self.lora,
+                LoraInfo(
+                    key=lora_key,
+                    submodel_type=None,
                    weight=self.weight,
                )
            )
@@ -254,11 +299,17 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
         return output
 
 
+class VAEModelField(BaseModel):
+    """Vae model field"""
+
+    key: str = Field(description="Model's key")
+
+
 @invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
-class VAELoaderInvocation(BaseInvocation):
+class VaeLoaderInvocation(BaseInvocation):
     """Loads a VAE model, outputting a VaeLoaderOutput"""
 
-    vae_model: ModelField = InputField(
+    vae_model: VAEModelField = InputField(
         description=FieldDescriptions.vae_model,
         input=Input.Direct,
         title="VAE",
@@ -270,7 +321,7 @@ class VAELoaderInvocation(BaseInvocation):
         if not context.models.exists(key):
             raise Exception(f"Unkown vae: {key}!")
 
-        return VAEOutput(vae=VAEField(vae=self.vae_model))
+        return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
 
 
 @invocation_output("seamless_output")
@@ -278,7 +329,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
     """Modified Seamless Model output"""
 
     unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
-    vae: Optional[VAEField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
+    vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
 
 
 @invocation(
@@ -297,7 +348,7 @@ class SeamlessModeInvocation(BaseInvocation):
         input=Input.Connection,
         title="UNet",
     )
-    vae: Optional[VAEField] = InputField(
+    vae: Optional[VaeField] = InputField(
         default=None,
         description=FieldDescriptions.vae_model,
         input=Input.Connection,
@ -8,7 +8,7 @@ from .baseinvocation import (
    invocation,
    invocation_output,
)
from .model import CLIPField, ModelField, UNetField, VAEField
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField


@invocation_output("sdxl_model_loader_output")
@ -16,9 +16,9 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
    """SDXL base model loader output"""

    unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
    clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
    clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
    clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
    vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")


@invocation_output("sdxl_refiner_model_loader_output")
@ -26,15 +26,15 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
    """SDXL refiner model loader output"""

    unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
    clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
    clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
    vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")


@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
class SDXLModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl base model, outputting its submodels."""

    model: ModelField = InputField(
    model: MainModelField = InputField(
        description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
    )
    # TODO: precision?
@ -46,19 +46,48 @@ class SDXLModelLoaderInvocation(BaseInvocation):
        if not context.models.exists(model_key):
            raise Exception(f"Unknown model: {model_key}")

        unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
        scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
        tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
        tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
        text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
        vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})

        return SDXLModelLoaderOutput(
            unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
            clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
            clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
            vae=VAEField(vae=vae),
            unet=UNetField(
                unet=ModelInfo(
                    key=model_key,
                    submodel_type=SubModelType.UNet,
                ),
                scheduler=ModelInfo(
                    key=model_key,
                    submodel_type=SubModelType.Scheduler,
                ),
                loras=[],
            ),
            clip=ClipField(
                tokenizer=ModelInfo(
                    key=model_key,
                    submodel_type=SubModelType.Tokenizer,
                ),
                text_encoder=ModelInfo(
                    key=model_key,
                    submodel_type=SubModelType.TextEncoder,
                ),
                loras=[],
                skipped_layers=0,
            ),
            clip2=ClipField(
                tokenizer=ModelInfo(
                    key=model_key,
                    submodel_type=SubModelType.Tokenizer2,
                ),
                text_encoder=ModelInfo(
                    key=model_key,
                    submodel_type=SubModelType.TextEncoder2,
                ),
                loras=[],
                skipped_layers=0,
            ),
            vae=VaeField(
                vae=ModelInfo(
                    key=model_key,
                    submodel_type=SubModelType.VAE,
                ),
            ),
        )


@ -72,8 +101,10 @@ class SDXLModelLoaderInvocation(BaseInvocation):
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl refiner model, outputting its submodels."""

    model: ModelField = InputField(
        description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
    model: MainModelField = InputField(
        description=FieldDescriptions.sdxl_refiner_model,
        input=Input.Direct,
        ui_type=UIType.SDXLRefinerModel,
    )
    # TODO: precision?

@ -84,14 +115,34 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
        if not context.models.exists(model_key):
            raise Exception(f"Unknown model: {model_key}")

        unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
        scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
        tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
        text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
        vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})

        return SDXLRefinerModelLoaderOutput(
            unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
            clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
            vae=VAEField(vae=vae),
            unet=UNetField(
                unet=ModelInfo(
                    key=model_key,
                    submodel_type=SubModelType.UNet,
                ),
                scheduler=ModelInfo(
                    key=model_key,
                    submodel_type=SubModelType.Scheduler,
                ),
                loras=[],
            ),
            clip2=ClipField(
                tokenizer=ModelInfo(
                    key=model_key,
                    submodel_type=SubModelType.Tokenizer2,
                ),
                text_encoder=ModelInfo(
                    key=model_key,
                    submodel_type=SubModelType.TextEncoder2,
                ),
                loras=[],
                skipped_layers=0,
            ),
            vae=VaeField(
                vae=ModelInfo(
                    key=model_key,
                    submodel_type=SubModelType.VAE,
                ),
            ),
        )
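Editor's note: the two loader bodies above differ only in how they point at submodels. One side derives each reference from the selected model with Pydantic v2's model_copy(update=...); the other builds an explicit ModelInfo per submodel. A minimal sketch of the model_copy pattern, using simplified stand-ins (Ref and Sub are hypothetical names, not InvokeAI classes):

from enum import Enum
from typing import Optional

from pydantic import BaseModel


class Sub(str, Enum):  # hypothetical stand-in for SubModelType
    UNet = "unet"
    VAE = "vae"


class Ref(BaseModel):  # hypothetical stand-in for the model field type
    key: str
    submodel_type: Optional[Sub] = None


base = Ref(key="abc123")
# model_copy returns a new instance with the given fields replaced, so one
# selected model fans out into per-submodel references without mutation.
unet = base.model_copy(update={"submodel_type": Sub.UNet})
vae = base.model_copy(update={"submodel_type": Sub.VAE})
assert base.submodel_type is None and unet.submodel_type is Sub.UNet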
@ -10,14 +10,17 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext


class T2IAdapterModelField(BaseModel):
    key: str = Field(description="Model record key for the T2I-Adapter model")


class T2IAdapterField(BaseModel):
    image: ImageField = Field(description="The T2I-Adapter image prompt.")
    t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.")
    t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
    weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
    begin_step_percent: float = Field(
        default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"

@ -52,7 +55,7 @@ class T2IAdapterInvocation(BaseInvocation):

    # Inputs
    image: ImageField = InputField(description="The IP-Adapter image prompt.")
    t2i_adapter_model: ModelField = InputField(
    t2i_adapter_model: T2IAdapterModelField = InputField(
        description="The T2I-Adapter model.",
        title="T2I-Adapter Model",
        input=Input.Direct,
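Editor's note: the validate_begin_end_step and validate_weights helpers imported above guard the adapter's step-range and weight inputs. Their bodies are not shown in this diff; the following is a plausible sketch of the checks they imply, with the implementations being assumptions rather than the real InvokeAI code:

from typing import Union


def validate_begin_end_step(begin_step_percent: float, end_step_percent: float) -> None:
    # The adapter must start no later than it ends.
    if begin_step_percent > end_step_percent:
        raise ValueError("Begin step percent must be less than or equal to end step percent")


def validate_weights(weight: Union[float, list[float]]) -> None:
    # Negative weights make no sense for adapter conditioning.
    values = weight if isinstance(weight, list) else [weight]
    if any(w < 0 for w in values):
        raise ValueError("Weights must be non-negative")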
@ -41,9 +41,8 @@ class InvocationCacheBase(ABC):
        """Clears the cache"""
        pass

    @staticmethod
    @abstractmethod
    def create_key(invocation: BaseInvocation) -> int:
    def create_key(self, invocation: BaseInvocation) -> int:
        """Gets the key for the invocation's cache item"""
        pass

@ -61,7 +61,9 @@ class MemoryInvocationCache(InvocationCacheBase):
            self._delete_oldest_access(number_to_delete)
            self._cache[key] = CachedItem(
                invocation_output,
                invocation_output.model_dump_json(warnings=False, exclude_defaults=True, exclude_unset=True),
                invocation_output.model_dump_json(
                    warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
                ),
            )

    def _delete_oldest_access(self, number_to_delete: int) -> None:
@ -79,7 +81,7 @@ class MemoryInvocationCache(InvocationCacheBase):
        with self._lock:
            return self._delete(key)

    def clear(self) -> None:
    def clear(self, *args, **kwargs) -> None:
        with self._lock:
            if self._max_cache_size == 0:
                return
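Editor's note: with create_key now an instance method, an implementation can still derive the key purely from the invocation's serialized form. A minimal sketch of one such key function; the exclude={"id"} choice is an assumption about which fields should not affect cache identity, not a quote of the real implementation:

from pydantic import BaseModel


class Invocation(BaseModel):  # simplified stand-in for BaseInvocation
    id: str
    type: str
    prompt: str = ""


def create_key(invocation: Invocation) -> int:
    # Two invocations with identical settings but different graph ids
    # should hit the same cache entry, so the id is excluded.
    return hash(invocation.model_dump_json(exclude={"id"}))


assert create_key(Invocation(id="a", type="noop")) == create_key(Invocation(id="b", type="noop"))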
@ -25,7 +25,6 @@ if TYPE_CHECKING:
    from .images.images_base import ImageServiceABC
    from .invocation_cache.invocation_cache_base import InvocationCacheBase
    from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
    from .model_images.model_images_base import ModelImageFileStorageBase
    from .model_manager.model_manager_base import ModelManagerServiceBase
    from .names.names_base import NameServiceBase
    from .session_processor.session_processor_base import SessionProcessorBase
@ -50,7 +49,6 @@ class InvocationServices:
        image_files: "ImageFileStorageBase",
        image_records: "ImageRecordStorageBase",
        logger: "Logger",
        model_images: "ModelImageFileStorageBase",
        model_manager: "ModelManagerServiceBase",
        download_queue: "DownloadQueueServiceBase",
        performance_statistics: "InvocationStatsServiceBase",
@ -74,7 +72,6 @@ class InvocationServices:
        self.image_files = image_files
        self.image_records = image_records
        self.logger = logger
        self.model_images = model_images
        self.model_manager = model_manager
        self.download_queue = download_queue
        self.performance_statistics = performance_statistics
@ -1,33 +0,0 @@
from abc import ABC, abstractmethod
from pathlib import Path

from PIL.Image import Image as PILImageType


class ModelImageFileStorageBase(ABC):
    """Low-level service responsible for storing and retrieving image files."""

    @abstractmethod
    def get(self, model_key: str) -> PILImageType:
        """Retrieves a model image as PIL Image."""
        pass

    @abstractmethod
    def get_path(self, model_key: str) -> Path:
        """Gets the internal path to a model image."""
        pass

    @abstractmethod
    def get_url(self, model_key: str) -> str | None:
        """Gets the URL to fetch a model image."""
        pass

    @abstractmethod
    def save(self, image: PILImageType, model_key: str) -> None:
        """Saves a model image."""
        pass

    @abstractmethod
    def delete(self, model_key: str) -> None:
        """Deletes a model image."""
        pass
@ -1,20 +0,0 @@
# TODO: Should these exceptions subclass existing python exceptions?
class ModelImageFileNotFoundException(Exception):
    """Raised when an image file is not found in storage."""

    def __init__(self, message="Model image file not found"):
        super().__init__(message)


class ModelImageFileSaveException(Exception):
    """Raised when an image cannot be saved."""

    def __init__(self, message="Model image file not saved"):
        super().__init__(message)


class ModelImageFileDeleteException(Exception):
    """Raised when an image cannot be deleted."""

    def __init__(self, message="Model image file not deleted"):
        super().__init__(message)
@ -1,79 +0,0 @@
from pathlib import Path

from PIL import Image
from PIL.Image import Image as PILImageType
from send2trash import send2trash

from invokeai.app.services.invoker import Invoker
from invokeai.app.util.thumbnails import make_thumbnail

from .model_images_base import ModelImageFileStorageBase
from .model_images_common import (
    ModelImageFileDeleteException,
    ModelImageFileNotFoundException,
    ModelImageFileSaveException,
)


class ModelImageFileStorageDisk(ModelImageFileStorageBase):
    """Stores images on disk"""

    def __init__(self, model_images_folder: Path):
        self._model_images_folder = model_images_folder
        self._validate_storage_folders()

    def start(self, invoker: Invoker) -> None:
        self._invoker = invoker

    def get(self, model_key: str) -> PILImageType:
        try:
            path = self.get_path(model_key)

            if not self._validate_path(path):
                raise ModelImageFileNotFoundException

            return Image.open(path)
        except FileNotFoundError as e:
            raise ModelImageFileNotFoundException from e

    def save(self, image: PILImageType, model_key: str) -> None:
        try:
            self._validate_storage_folders()
            image_path = self._model_images_folder / (model_key + ".webp")
            thumbnail = make_thumbnail(image, 256)
            thumbnail.save(image_path, format="webp")

        except Exception as e:
            raise ModelImageFileSaveException from e

    def get_path(self, model_key: str) -> Path:
        path = self._model_images_folder / (model_key + ".webp")

        return path

    def get_url(self, model_key: str) -> str | None:
        path = self.get_path(model_key)
        if not self._validate_path(path):
            return

        return self._invoker.services.urls.get_model_image_url(model_key)

    def delete(self, model_key: str) -> None:
        try:
            path = self.get_path(model_key)

            if not self._validate_path(path):
                raise ModelImageFileNotFoundException

            send2trash(path)

        except Exception as e:
            raise ModelImageFileDeleteException from e

    def _validate_path(self, path: Path) -> bool:
        """Validates the path given for an image."""
        return path.exists()

    def _validate_storage_folders(self) -> None:
        """Checks if the required folders exist and creates them if they don't"""
        self._model_images_folder.mkdir(parents=True, exist_ok=True)
@ -4,7 +4,6 @@ import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
@ -281,18 +280,12 @@ class ModelInstallService(ModelInstallServiceBase):
        self._scan_models_directory()
        if autoimport := self._app_config.autoimport_dir:
            self._logger.info("Scanning autoimport directory for new models")
            installed: List[str] = []
            # Use ThreadPoolExecutor to scan dirs in parallel
            with ThreadPoolExecutor() as executor:
                future_models = [executor.submit(self.scan_directory, self._app_config.root_path / autoimport / cur_model_type.value) for cur_model_type in ModelType]
                [installed.extend(models.result()) for models in as_completed(future_models)]
            installed = self.scan_directory(self._app_config.root_path / autoimport)
            self._logger.info(f"{len(installed)} new models registered")
        self._logger.info("Model installer (re)initialized")

    def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:  # noqa D102
        self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
        if len([entry for entry in os.scandir(scan_dir) if not entry.name.startswith(".")]) == 0:
            return []
        callback = self._scan_install if install else self._scan_register
        search = ModelSearch(on_model_found=callback, config=self._app_config)
        self._models_installed.clear()
@ -455,10 +448,10 @@ class ModelInstallService(ModelInstallServiceBase):
                self.unregister(key)

        self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
        # Use ThreadPoolExecutor to scan dirs in parallel
        with ThreadPoolExecutor() as executor:
            future_models = [executor.submit(self.scan_directory, Path(cur_base_model.value, cur_model_type.value)) for cur_base_model in BaseModelType for cur_model_type in ModelType]
            [installed.update(models.result()) for models in as_completed(future_models)]
        for cur_base_model in BaseModelType:
            for cur_model_type in ModelType:
                models_dir = Path(cur_base_model.value, cur_model_type.value)
                installed.update(self.scan_directory(models_dir))
        self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")

    def _sync_model_path(self, key: str) -> AnyModelConfig:
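Editor's note: the hunks above trade a ThreadPoolExecutor fan-out over per-type model directories for a plain nested loop. The list-comprehension-used-for-side-effects style in the parallel version works but is unidiomatic; a cleaner, self-contained sketch of the same fan-out pattern (the directory list and the scan callable are placeholders, not InvokeAI APIs):

from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import List


def scan(directory: Path) -> List[str]:
    # Placeholder for ModelInstallService.scan_directory.
    return []


def scan_all(dirs: List[Path]) -> List[str]:
    installed: List[str] = []
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(scan, d) for d in dirs]
        for future in as_completed(futures):
            # Collect results as workers finish rather than in submit order.
            installed.extend(future.result())
    return installed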
@ -1,11 +1,15 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team

from abc import ABC, abstractmethod
from typing import Optional

import torch
from typing_extensions import Self

from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel

from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
@ -66,3 +70,32 @@ class ModelManagerServiceBase(ABC):
    @abstractmethod
    def stop(self, invoker: Invoker) -> None:
        pass

    @abstractmethod
    def load_model_by_config(
        self,
        model_config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
        context_data: Optional[InvocationContextData] = None,
    ) -> LoadedModel:
        pass

    @abstractmethod
    def load_model_by_key(
        self,
        key: str,
        submodel_type: Optional[SubModelType] = None,
        context_data: Optional[InvocationContextData] = None,
    ) -> LoadedModel:
        pass

    @abstractmethod
    def load_model_by_attr(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        submodel: Optional[SubModelType] = None,
        context_data: Optional[InvocationContextData] = None,
    ) -> LoadedModel:
        pass
@ -1,10 +1,14 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""

from typing import Optional

import torch
from typing_extensions import Self

from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
@ -14,7 +18,7 @@ from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase
from ..model_records import ModelRecordServiceBase, UnknownModelException
from .model_manager_base import ModelManagerServiceBase


@ -60,6 +64,56 @@ class ModelManagerService(ModelManagerServiceBase):
            if hasattr(service, "stop"):
                service.stop(invoker)

    def load_model_by_config(
        self,
        model_config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
        context_data: Optional[InvocationContextData] = None,
    ) -> LoadedModel:
        return self.load.load_model(model_config, submodel_type, context_data)

    def load_model_by_key(
        self,
        key: str,
        submodel_type: Optional[SubModelType] = None,
        context_data: Optional[InvocationContextData] = None,
    ) -> LoadedModel:
        config = self.store.get_model(key)
        return self.load.load_model(config, submodel_type, context_data)

    def load_model_by_attr(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        submodel: Optional[SubModelType] = None,
        context_data: Optional[InvocationContextData] = None,
    ) -> LoadedModel:
        """
        Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.

        This is provided for API compatibility with the get_model() method
        in the original model manager. However, note that LoadedModel is
        not the same as the original ModelInfo that was returned.

        :param model_name: Name of the model to be fetched.
        :param base_model: Base model
        :param model_type: Type of the model
        :param submodel: For main (pipeline models), the submodel to fetch
        :param context: The invocation context.

        Exceptions: UnknownModelException -- model with this key not known
                    NotImplementedException -- a model loader was not provided at initialization time
                    ValueError -- more than one model matches this combination
        """
        configs = self.store.search_by_attr(model_name, base_model, model_type)
        if len(configs) == 0:
            raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
        elif len(configs) > 1:
            raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
        else:
            return self.load.load_model(configs[0], submodel, context_data)

    @classmethod
    def build_model_manager(
        cls,
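Editor's note: a short usage sketch for the three loaders added above. It assumes model_manager is an already-built ModelManagerService; the model name is illustrative and the key in the fallback is hypothetical:

from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType

try:
    loaded = model_manager.load_model_by_attr(
        model_name="stable-diffusion-xl-base",  # illustrative name
        base_model=BaseModelType.StableDiffusionXL,
        model_type=ModelType.Main,
        submodel=SubModelType.UNet,
    )
except ValueError:
    # More than one record matched the attributes; fall back to an
    # unambiguous record key instead.
    loaded = model_manager.load_model_by_key(key="abc123")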
@ -79,7 +79,6 @@ class ModelRecordChanges(BaseModelExcludeNull):
        description="The prediction type of the model.", default=None
    )
    upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
    config_path: Optional[str] = Field(description="Path to config file for model", default=None)


class ModelRecordServiceBase(ABC):
@ -130,17 +129,6 @@ class ModelRecordServiceBase(ABC):
        """
        pass

    @abstractmethod
    def get_model_by_hash(self, hash: str) -> AnyModelConfig:
        """
        Retrieve the configuration for the indicated model.

        :param hash: Hash of model config to be fetched.

        Exceptions: UnknownModelException
        """
        pass

    @abstractmethod
    def list_models(
        self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
@ -203,21 +203,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
            model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
        return model

    def get_model_by_hash(self, hash: str) -> AnyModelConfig:
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT config, strftime('%s',updated_at) FROM models
                WHERE hash=?;
                """,
                (hash,),
            )
            rows = self._cursor.fetchone()
            if not rows:
                raise UnknownModelException("model not found")
            model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
        return model

    def exists(self, key: str) -> bool:
        """
        Return True if a model with the indicated key exists in the database.
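Editor's note: the removed get_model_by_hash illustrates the pattern this service uses throughout: a parameterized SQLite query, with fetchone returning either None or a row tuple. A standalone sketch of the same pattern against a plain sqlite3 connection (table and columns are simplified stand-ins):

import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE models (hash TEXT, config TEXT)")
conn.execute("INSERT INTO models VALUES (?, ?)", ("abc", json.dumps({"name": "demo"})))


def get_model_by_hash(hash: str) -> dict:
    # The `?` placeholder lets sqlite3 bind the value safely instead of
    # interpolating it into the SQL string.
    row = conn.execute("SELECT config FROM models WHERE hash=?", (hash,)).fetchone()
    if row is None:
        raise LookupError("model not found")
    return json.loads(row[0])


assert get_model_by_hash("abc")["name"] == "demo"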
@ -1,6 +1,35 @@
from abc import ABC, abstractmethod
from threading import Event

from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem


class SessionRunnerBase(ABC):
    """
    Base class for session runner.
    """

    @abstractmethod
    def start(self, services: InvocationServices, cancel_event: Event) -> None:
        """Starts the session runner"""
        pass

    @abstractmethod
    def run(self, queue_item: SessionQueueItem) -> None:
        """Runs the session"""
        pass

    @abstractmethod
    def complete(self, queue_item: SessionQueueItem) -> None:
        """Completes the session"""
        pass

    @abstractmethod
    def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
        """Runs an already prepared node on the session"""
        pass


class SessionProcessorBase(ABC):
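Editor's note: because the processor delegates to a SessionRunnerBase, alternative execution strategies can be slotted in without touching the queue-polling loop. A minimal custom-runner sketch that just logs around the default behavior; it subclasses DefaultSessionRunner (defined in the next file below) rather than reimplementing node execution, and assumes the runner has been started so self.services is set:

from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem


class LoggingSessionRunner(DefaultSessionRunner):
    """Example runner that wraps each node run with log lines."""

    def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
        self.services.logger.info(f"Running node {node_id}")
        super().run_node(node_id, queue_item)
        self.services.logger.info(f"Finished node {node_id}")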
@ -2,13 +2,14 @@ import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional
from typing import Callable, Optional, Union

from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent

from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
@ -16,15 +17,164 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat
from invokeai.app.util.profiler import Profiler

from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase
from .session_processor_base import SessionProcessorBase, SessionRunnerBase
from .session_processor_common import SessionProcessorStatus


class DefaultSessionRunner(SessionRunnerBase):
    """Processes a single session's invocations"""

    def __init__(
        self,
        on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
        on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
    ):
        self.on_before_run_node = on_before_run_node
        self.on_after_run_node = on_after_run_node

    def start(self, services: InvocationServices, cancel_event: ThreadEvent):
        """Start the session runner"""
        self.services = services
        self.cancel_event = cancel_event

    def run(self, queue_item: SessionQueueItem):
        """Run the graph"""
        if not queue_item.session:
            raise ValueError("Queue item has no session")
        # Loop over invocations until the session is complete or canceled
        while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
            # Prepare the next node
            invocation = queue_item.session.next()
            if invocation is None:
                # If there are no more invocations, complete the graph
                break
            # Build invocation context (the node-facing API)
            self.run_node(invocation.id, queue_item)
        self.complete(queue_item)

    def complete(self, queue_item: SessionQueueItem):
        """Complete the graph"""
        self.services.events.emit_graph_execution_complete(
            queue_batch_id=queue_item.batch_id,
            queue_item_id=queue_item.item_id,
            queue_id=queue_item.queue_id,
            graph_execution_state_id=queue_item.session.id,
        )

    def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
        """Run before a node is executed"""
        # Send starting event
        self.services.events.emit_invocation_started(
            queue_batch_id=queue_item.batch_id,
            queue_item_id=queue_item.item_id,
            queue_id=queue_item.queue_id,
            graph_execution_state_id=queue_item.session_id,
            node=invocation.model_dump(),
            source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
        )
        if self.on_before_run_node is not None:
            self.on_before_run_node(invocation, queue_item)

    def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
        """Run after a node is executed"""
        if self.on_after_run_node is not None:
            self.on_after_run_node(invocation, queue_item)

    def run_node(self, node_id: str, queue_item: SessionQueueItem):
        """Run a single node in the graph"""
        # If this raises a NodeNotFoundError, it is handled by the processor
        invocation = queue_item.session.execution_graph.get_node(node_id)
        try:
            self._on_before_run_node(invocation, queue_item)
            data = InvocationContextData(
                invocation=invocation,
                source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
                queue_item=queue_item,
            )

            # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
            with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
                context = build_invocation_context(
                    data=data,
                    services=self.services,
                    cancel_event=self.cancel_event,
                )

                # Invoke the node
                outputs = invocation.invoke_internal(context=context, services=self.services)

                # Save outputs and history
                queue_item.session.complete(invocation.id, outputs)

            self._on_after_run_node(invocation, queue_item)
            # Send complete event on successful runs
            self.services.events.emit_invocation_complete(
                queue_batch_id=queue_item.batch_id,
                queue_item_id=queue_item.item_id,
                queue_id=queue_item.queue_id,
                graph_execution_state_id=queue_item.session.id,
                node=invocation.model_dump(),
                source_node_id=data.source_invocation_id,
                result=outputs.model_dump(),
            )
        except KeyboardInterrupt:
            # TODO(MM2): Create an event for this
            pass
        except CanceledException:
            # When the user cancels the graph, we first set the cancel event. The event is checked
            # between invocations, in this loop. Some invocations are long-running, and we need to
            # be able to cancel them mid-execution.
            #
            # For example, denoising is a long-running invocation with many steps. A step callback
            # is executed after each step. This step callback checks if the canceled event is set,
            # then raises a CanceledException to stop execution immediately.
            #
            # When we get a CanceledException, we don't need to do anything - just pass and let the
            # loop go to its next iteration, and the cancel event will be handled correctly.
            pass
        except Exception as e:
            error = traceback.format_exc()

            # Save error
            queue_item.session.set_node_error(invocation.id, error)
            self.services.logger.error(
                f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
            )
            self.services.logger.error(error)

            # Send error event
            self.services.events.emit_invocation_error(
                queue_batch_id=queue_item.session_id,
                queue_item_id=queue_item.item_id,
                queue_id=queue_item.queue_id,
                graph_execution_state_id=queue_item.session.id,
                node=invocation.model_dump(),
                source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
                error_type=e.__class__.__name__,
                error=error,
            )


class DefaultSessionProcessor(SessionProcessorBase):
    def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
        """Processes sessions from the session queue"""

    def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
        super().__init__()
        self.session_runner = session_runner if session_runner else DefaultSessionRunner()

    def start(
        self,
        invoker: Invoker,
        thread_limit: int = 1,
        polling_interval: int = 1,
        on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
        on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
    ) -> None:
        self._invoker: Invoker = invoker
        self._queue_item: Optional[SessionQueueItem] = None
        self._invocation: Optional[BaseInvocation] = None
        self.on_before_run_session = on_before_run_session
        self.on_after_run_session = on_after_run_session

        self._resume_event = ThreadEvent()
        self._stop_event = ThreadEvent()
@ -59,6 +209,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
                "cancel_event": self._cancel_event,
            },
        )
        self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
        self._thread.start()

    def stop(self, *args, **kwargs) -> None:
@ -117,131 +268,34 @@ class DefaultSessionProcessor(SessionProcessorBase):
                    self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
                    cancel_event.clear()

                    # If we have a on_before_run_session callback, call it
                    if self.on_before_run_session is not None:
                        self.on_before_run_session(self._queue_item)

                    # If profiling is enabled, start the profiler
                    if self._profiler is not None:
                        self._profiler.start(profile_id=self._queue_item.session_id)

                    # Prepare invocations and take the first
                    self._invocation = self._queue_item.session.next()
                    # Run the graph
                    self.session_runner.run(queue_item=self._queue_item)

                    # Loop over invocations until the session is complete or canceled
                    while self._invocation is not None and not cancel_event.is_set():
                        # get the source node id to provide to clients (the prepared node id is not as useful)
                        source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]

                        # Send starting event
                        self._invoker.services.events.emit_invocation_started(
                            queue_batch_id=self._queue_item.batch_id,
                            queue_item_id=self._queue_item.item_id,
                            queue_id=self._queue_item.queue_id,
                            graph_execution_state_id=self._queue_item.session_id,
                            node=self._invocation.model_dump(),
                            source_node_id=source_invocation_id,
                    # If we are profiling, stop the profiler and dump the profile & stats
                    if self._profiler:
                        profile_path = self._profiler.stop()
                        stats_path = profile_path.with_suffix(".json")
                        self._invoker.services.performance_statistics.dump_stats(
                            graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
                        )

                        # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
                        try:
                            with self._invoker.services.performance_statistics.collect_stats(
                                self._invocation, self._queue_item.session.id
                            ):
                                # Build invocation context (the node-facing API)
                                data = InvocationContextData(
                                    invocation=self._invocation,
                                    source_invocation_id=source_invocation_id,
                                    queue_item=self._queue_item,
                                )
                                context = build_invocation_context(
                                    data=data,
                                    services=self._invoker.services,
                                    cancel_event=self._cancel_event,
                                )
                    # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
                    # we don't care about that - suppress the error.
                    with suppress(GESStatsNotFoundError):
                        self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
                        self._invoker.services.performance_statistics.reset_stats()

                                # Invoke the node
                                outputs = self._invocation.invoke_internal(
                                    context=context, services=self._invoker.services
                                )

                                # Save outputs and history
                                self._queue_item.session.complete(self._invocation.id, outputs)

                            # Send complete event
                            self._invoker.services.events.emit_invocation_complete(
                                queue_batch_id=self._queue_item.batch_id,
                                queue_item_id=self._queue_item.item_id,
                                queue_id=self._queue_item.queue_id,
                                graph_execution_state_id=self._queue_item.session.id,
                                node=self._invocation.model_dump(),
                                source_node_id=source_invocation_id,
                                result=outputs.model_dump(),
                            )

                        except KeyboardInterrupt:
                            # TODO(MM2): Create an event for this
                            pass

                        except CanceledException:
                            # When the user cancels the graph, we first set the cancel event. The event is checked
                            # between invocations, in this loop. Some invocations are long-running, and we need to
                            # be able to cancel them mid-execution.
                            #
                            # For example, denoising is a long-running invocation with many steps. A step callback
                            # is executed after each step. This step callback checks if the canceled event is set,
                            # then raises a CanceledException to stop execution immediately.
                            #
                            # When we get a CanceledException, we don't need to do anything - just pass and let the
                            # loop go to its next iteration, and the cancel event will be handled correctly.
                            pass

                        except Exception as e:
                            error = traceback.format_exc()

                            # Save error
                            self._queue_item.session.set_node_error(self._invocation.id, error)
                            self._invoker.services.logger.error(
                                f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
                            )
                            self._invoker.services.logger.error(error)

                            # Send error event
                            self._invoker.services.events.emit_invocation_error(
                                queue_batch_id=self._queue_item.session_id,
                                queue_item_id=self._queue_item.item_id,
                                queue_id=self._queue_item.queue_id,
                                graph_execution_state_id=self._queue_item.session.id,
                                node=self._invocation.model_dump(),
                                source_node_id=source_invocation_id,
                                error_type=e.__class__.__name__,
                                error=error,
                            )
                            pass

                        # The session is complete if all invocations are complete or there was an error
                        if self._queue_item.session.is_complete() or cancel_event.is_set():
                            # Send complete event
                            self._invoker.services.events.emit_graph_execution_complete(
                                queue_batch_id=self._queue_item.batch_id,
                                queue_item_id=self._queue_item.item_id,
                                queue_id=self._queue_item.queue_id,
                                graph_execution_state_id=self._queue_item.session.id,
                            )
                            # If we are profiling, stop the profiler and dump the profile & stats
                            if self._profiler:
                                profile_path = self._profiler.stop()
                                stats_path = profile_path.with_suffix(".json")
                                self._invoker.services.performance_statistics.dump_stats(
                                    graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
                                )
                            # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
                            # we don't care about that - suppress the error.
                            with suppress(GESStatsNotFoundError):
                                self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
                                self._invoker.services.performance_statistics.reset_stats()

                            # Set the invocation to None to prepare for the next session
                            self._invocation = None
                        else:
                            # Prepare the next invocation
                            self._invocation = self._queue_item.session.next()
                    # If we have a on_after_run_session callback, call it
                    if self.on_after_run_session is not None:
                        self.on_after_run_session(self._queue_item)

                    # The session is complete, immediately poll for next session
                    self._queue_item = None
@ -275,3 +329,4 @@ class DefaultSessionProcessor(SessionProcessorBase):
                    poll_now_event.clear()
                    self._queue_item = None
                    self._thread_semaphore.release()
                self._invoker.services.logger.debug("Session processor stopped")
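Editor's note: taken together, the new classes let callers inject behavior at both the session and node level. A usage sketch based on the signatures above; the callback body is illustrative, and the classes are assumed to be imported from the module in this diff:

from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem


def log_node(invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool:
    print(f"node {invocation.id} in queue item {queue_item.item_id}")
    return True


processor = DefaultSessionProcessor(
    session_runner=DefaultSessionRunner(
        on_before_run_node=log_node,
        on_after_run_node=log_node,
    )
)
# The runner receives the services object and the shared cancel event when
# the processor starts, via processor.start(invoker, ...).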
@ -1,7 +1,7 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Optional

from PIL.Image import Image
from torch import Tensor
@ -13,16 +13,15 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData

if TYPE_CHECKING:
    from invokeai.app.invocations.baseinvocation import BaseInvocation
    from invokeai.app.invocations.model import ModelField
    from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem

"""
@ -300,25 +299,22 @@ class ConditioningInterface(InvocationContextInterface):


class ModelsInterface(InvocationContextInterface):
    def exists(self, identifier: Union[str, "ModelField"]) -> bool:
    def exists(self, key: str) -> bool:
        """Checks if a model exists.

        Args:
            identifier: The key or ModelField representing the model.
            key: The key of the model.

        Returns:
            True if the model exists, False if not.
        """
        if isinstance(identifier, str):
            return self._services.model_manager.store.exists(identifier)
        return self._services.model_manager.store.exists(key)

        return self._services.model_manager.store.exists(identifier.key)

    def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel:
    def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
        """Loads a model.

        Args:
            identifier: The key or ModelField representing the model.
            key: The key of the model.
            submodel_type: The submodel of the model to get.

        Returns:
@ -328,13 +324,9 @@ class ModelsInterface(InvocationContextInterface):
        # The model manager emits events as it loads the model. It needs the context data to build
        # the event payloads.

        if isinstance(identifier, str):
            model = self._services.model_manager.store.get_model(identifier)
            return self._services.model_manager.load.load_model(model, submodel_type, self._data)
        else:
            _submodel_type = submodel_type or identifier.submodel_type
            model = self._services.model_manager.store.get_model(identifier.key)
            return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
        return self._services.model_manager.load_model_by_key(
            key=key, submodel_type=submodel_type, context_data=self._data
        )

    def load_by_attrs(
        self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@ -351,29 +343,35 @@ class ModelsInterface(InvocationContextInterface):
        Returns:
            An object representing the loaded model.
        """
        return self._services.model_manager.load_model_by_attr(
            model_name=name,
            base_model=base,
            model_type=type,
            submodel=submodel_type,
            context_data=self._data,
        )

        configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
        if len(configs) == 0:
            raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")

        if len(configs) > 1:
            raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")

        return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)

    def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig:
    def get_config(self, key: str) -> AnyModelConfig:
        """Gets a model's config.

        Args:
            identifier: The key or ModelField representing the model.
            key: The key of the model.

        Returns:
            The model's config.
        """
        if isinstance(identifier, str):
            return self._services.model_manager.store.get_model(identifier)
        return self._services.model_manager.store.get_model(key=key)

        return self._services.model_manager.store.get_model(identifier.key)

    def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
        """Gets a model's metadata, if it has any.

        Args:
            key: The key of the model.

        Returns:
            The model's metadata, if it has any.
        """
        return self._services.model_manager.store.get_metadata(key=key)

    def search_by_path(self, path: Path) -> list[AnyModelConfig]:
        """Searches for models by path.
@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
    See :class:`Migration` for an example.
    """

    def __call__(self, cursor: sqlite3.Cursor) -> None: ...
    def __call__(self, cursor: sqlite3.Cursor) -> None:
        ...


class MigrationError(RuntimeError):
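Editor's note: only the formatting of the Protocol stub changes here; either spelling accepts any callable with a matching signature. A sketch of a conforming migration callback (the table and column names are invented for illustration):

import sqlite3


def migrate_add_cover_image(cursor: sqlite3.Cursor) -> None:
    # Any plain function with this signature satisfies MigrateCallback.
    cursor.execute("ALTER TABLE models ADD COLUMN cover_image TEXT")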
@ -8,8 +8,3 @@ class UrlServiceBase(ABC):
    def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
        """Gets the URL for an image or thumbnail."""
        pass

    @abstractmethod
    def get_model_image_url(self, model_key: str) -> str:
        """Gets the URL for a model image"""
        pass
@ -4,9 +4,8 @@ from .urls_base import UrlServiceBase


class LocalUrlService(UrlServiceBase):
    def __init__(self, base_url: str = "api/v1", base_url_v2: str = "api/v2"):
    def __init__(self, base_url: str = "api/v1"):
        self._base_url = base_url
        self._base_url_v2 = base_url_v2

    def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
        image_basename = os.path.basename(image_name)
@ -16,6 +15,3 @@ class LocalUrlService(UrlServiceBase):
            return f"{self._base_url}/images/i/{image_basename}/thumbnail"

        return f"{self._base_url}/images/i/{image_basename}/full"

    def get_model_image_url(self, model_key: str) -> str:
        return f"{self._base_url_v2}/models/i/{model_key}/image"
@ -22,7 +22,7 @@ def generate_ti_list(
    for trigger in extract_ti_triggers_from_prompt(prompt):
        name_or_key = trigger[1:-1]
        try:
            loaded_model = context.models.load(name_or_key)
            loaded_model = context.models.load(key=name_or_key)
            model = loaded_model.model
            assert isinstance(model, TextualInversionModelRaw)
            assert loaded_model.config.base == base
@ -19,6 +19,7 @@ from invokeai.app.services.model_install import (
    ModelInstallService,
    ModelInstallServiceBase,
)
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
@ -38,7 +39,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
    logger = InvokeAILogger.get_logger(config=app_config)
    image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
    db = init_db(config=app_config, logger=logger, image_files=image_files)
    obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
    obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
    return obj
@ -161,7 +161,6 @@ class ModelConfigBase(BaseModel):
    default_settings: Optional[ModelDefaultSettings] = Field(
        description="Default settings for this model", default=None
    )
    cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)

    @staticmethod
    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
@ -311,7 +310,7 @@ class IPAdapterConfig(ModelConfigBase):


class CLIPVisionDiffusersConfig(ModelConfigBase):
    """Model config for CLIPVision."""
    """Model config for ClipVision."""

    type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
    format: Literal[ModelFormat.Diffusers]
@ -12,8 +12,6 @@ import hashlib
import os
from pathlib import Path
from typing import Callable, Literal, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed


from blake3 import blake3

@ -107,14 +105,13 @@ class ModelHash:
        """
        model_component_paths = self._get_file_paths(dir, self._file_filter)

        # Use ThreadPoolExecutor to hash files in parallel
        with ThreadPoolExecutor(min(((os.cpu_count() or 1) + 4), len(model_component_paths))) as executor:
            future_to_component = {executor.submit(self._hash_file, component): component for component in sorted(model_component_paths)}
            component_hashes = [future.result() for future in as_completed(future_to_component)]
        component_hashes: list[str] = []
        for component in sorted(model_component_paths):
            component_hashes.append(self._hash_file(component))

        # BLAKE3 to hash the hashes
        # BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
        # for the composite hash
        composite_hasher = blake3()
        component_hashes.sort()
        for h in component_hashes:
            composite_hasher.update(h.encode("utf-8"))
        return composite_hasher.hexdigest()
@ -132,12 +129,10 @@ class ModelHash:
        """

        files: list[Path] = []
        entries = [entry for entry in os.scandir(model_path.as_posix()) if not entry.name.startswith(".")]
        dirs = [entry for entry in entries if entry.is_dir()]
        file_paths = [entry.path for entry in entries if entry.is_file() and file_filter(entry.path)]
        files.extend([Path(file) for file in file_paths])
        for dir in dirs:
            files.extend(ModelHash._get_file_paths(Path(dir.path), file_filter))
        for root, _dirs, _files in os.walk(model_path):
            for file in _files:
                if file_filter(file):
                    files.append(Path(root, file))
        return files

    @staticmethod
@ -166,11 +161,13 @@ class ModelHash:
        """

        def hashlib_hasher(file_path: Path) -> str:
            """Hashes a file using a hashlib algorithm."""
            """Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
            hasher = hashlib.new(algorithm)
            with open(file_path, "rb") as f:
                for chunk in iter(lambda: f.read(8 * 1024), b""):
                    hasher.update(chunk)
            buffer = bytearray(128 * 1024)
            mv = memoryview(buffer)
            with open(file_path, "rb", buffering=0) as f:
                while n := f.readinto(mv):
                    hasher.update(mv[:n])
            return hasher.hexdigest()

        return hashlib_hasher
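Editor's note: the readinto loop above reuses one 128 KiB buffer for every chunk, where the older f.read() loop allocated a fresh bytes object per iteration. A self-contained sketch of the same technique (the digest choice is arbitrary):

import hashlib
from pathlib import Path


def hash_file(file_path: Path, algorithm: str = "sha256") -> str:
    hasher = hashlib.new(algorithm)
    buffer = bytearray(128 * 1024)
    mv = memoryview(buffer)
    # buffering=0 gives a raw file object, so readinto fills our buffer
    # directly without an intermediate copy.
    with open(file_path, "rb", buffering=0) as f:
        while n := f.readinto(mv):
            hasher.update(mv[:n])
    return hasher.hexdigest()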
@ -24,7 +24,7 @@ from .. import ModelLoader, ModelLoaderRegistry

@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
class LoRALoader(ModelLoader):
class LoraLoader(ModelLoader):
    """Class to load LoRA models."""

    # We cheat a little bit to get access to the model base
@ -23,7 +23,7 @@ from .generic_diffusers import GenericDiffusersLoader

@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VAELoader(GenericDiffusersLoader):
class VaeLoader(GenericDiffusersLoader):
    """Class to load VAE models."""

    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
@ -84,9 +84,6 @@ class ProbeBase(object):


class ModelProbe(object):

    hasher = ModelHash()

    PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
        "diffusers": {},
        "checkpoint": {},
@ -160,7 +157,7 @@ class ModelProbe(object):
            fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
        )
        fields["format"] = fields.get("format") or probe.get_format()
        fields["hash"] = fields.get("hash") or cls.hasher.hash(model_path)
        fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)

        if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
            fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
@ -858,9 +858,9 @@ def do_textual_inversion_training(
                # Let's make sure we don't update any embedding weights besides the newly added token
                index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
                with torch.no_grad():
                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
                        orig_embeds_params[index_no_updates]
                    )
                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
                        index_no_updates
                    ] = orig_embeds_params[index_no_updates]

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
@ -42,10 +42,9 @@ def install_and_load_model(
    # If the requested model is already installed, return its LoadedModel
    with contextlib.suppress(UnknownModelException):
        # TODO: Replace with wrapper call
        configs = model_manager.store.search_by_attr(
        loaded_model: LoadedModel = model_manager.load_model_by_attr(
            model_name=model_name, base_model=base_model, model_type=model_type
        )
        loaded_model: LoadedModel = model_manager.load.load_model(configs[0])
        return loaded_model

    # Install the requested model.
@ -54,7 +53,7 @@ def install_and_load_model(
    assert job.complete

    try:
        loaded_model = model_manager.load.load_model(job.config_out)
        loaded_model = model_manager.load_model_by_config(job.config_out)
        return loaded_model
    except UnknownModelException as e:
        raise Exception(
@ -20,6 +20,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
@ -412,7 +413,7 @@ def get_config_store() -> ModelRecordServiceSQL:
    assert output_path is not None
    image_files = DiskImageFileStorage(output_path / "images")
    db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
    return ModelRecordServiceSQL(db)
    return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))


def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:
@ -746,7 +746,6 @@
    "delete": "Delete",
    "deleteConfig": "Delete Config",
    "deleteModel": "Delete Model",
    "deleteModelImage": "Delete Model Image",
    "deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
    "deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
    "description": "Description",
@ -787,10 +786,6 @@
    "modelDeleteFailed": "Failed to delete model",
    "modelEntryDeleted": "Model Entry Deleted",
    "modelExists": "Model Exists",
    "modelImageDeleted": "Model Image Deleted",
    "modelImageDeleteFailed": "Model Image Delete Failed",
    "modelImageUpdated": "Model Image Updated",
    "modelImageUpdateFailed": "Model Image Update Failed",
    "modelLocation": "Model Location",
    "modelLocationValidationMsg": "Provide the path to a local folder where your Diffusers Model is stored",
    "modelManager": "Model Manager",
@ -823,7 +818,6 @@
    "oliveModels": "Olives",
    "onnxModels": "Onnx",
    "path": "Path",
    "pathToConfig": "Path To Config",
    "pathToCustomConfig": "Path To Custom Config",
    "pickModelType": "Pick Model Type",
    "predictionType": "Prediction Type",
@ -858,7 +852,6 @@
    "triggerPhrases": "Trigger Phrases",
    "typePhraseHere": "Type phrase here",
    "upcastAttention": "Upcast Attention",
    "uploadImage": "Upload Image",
    "updateModel": "Update Model",
    "useCustomConfig": "Use Custom Config",
    "useDefaultSettings": "Use Default Settings",
@ -951,7 +944,6 @@
    "doesNotExist": "does not exist",
    "downloadWorkflow": "Download Workflow JSON",
    "edge": "Edge",
    "edit": "Edit",
    "editMode": "Edit in Workflow Editor",
    "enum": "Enum",
    "enumDescription": "Enums are values that may be one of a number of options.",
@ -1027,7 +1019,6 @@
    "nodeTemplate": "Node Template",
    "nodeType": "Node Type",
    "noFieldsLinearview": "No fields added to Linear View",
    "noFieldsViewMode": "This workflow has no selected fields to display. View the full workflow to configure values.",
    "noFieldType": "No field type",
    "noImageFoundState": "No initial image found in state",
    "noMatchingNodes": "No matching nodes",
@ -1815,7 +1806,6 @@
    "cursorPosition": "Cursor Position",
    "darkenOutsideSelection": "Darken Outside Selection",
    "discardAll": "Discard All",
    "discardCurrent": "Discard Current",
    "downloadAsImage": "Download As Image",
    "emptyFolder": "Empty Folder",
    "emptyTempImageFolder": "Empty Temp Image Folder",
@ -1825,7 +1815,6 @@
    "eraseBoundingBox": "Erase Bounding Box",
    "eraser": "Eraser",
    "fillBoundingBox": "Fill Bounding Box",
    "invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
    "layer": "Layer",
    "limitStrokesToBox": "Limit Strokes to Box",
    "mask": "Mask",
@ -115,8 +115,7 @@
"safetensors": "Safetensors",
"ai": "ia",
"file": "File",
"toResolve": "Da risolvere",
"add": "Aggiungi"
"toResolve": "Da risolvere"
},
"gallery": {
"generations": "Generazioni",
@ -154,12 +153,7 @@
"starImage": "Immagine preferita",
"dropToUpload": "$t(gallery.drop) per aggiornare",
"problemDeletingImagesDesc": "Impossibile eliminare una o più immagini",
"problemDeletingImages": "Problema durante l'eliminazione delle immagini",
"bulkDownloadRequested": "Preparazione del download",
"bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. L'operazione potrebbe richiedere alcuni istanti.",
"bulkDownloadRequestFailed": "Problema durante la preparazione del download",
"bulkDownloadStarting": "Avvio scaricamento",
"bulkDownloadFailed": "Scaricamento fallito"
"problemDeletingImages": "Problema durante l'eliminazione delle immagini"
},
"hotkeys": {
"keyboardShortcuts": "Tasti di scelta rapida",
@ -511,12 +505,12 @@
"modelSyncFailed": "Sincronizzazione modello non riuscita",
"settings": "Impostazioni",
"syncModels": "Sincronizza Modelli",
"syncModelsDesc": "Se i tuoi modelli non sono sincronizzati con il back-end, puoi aggiornarli utilizzando questa opzione. Questo è generalmente utile nei casi in cui aggiungi modelli alla cartella principale di InvokeAI dopo l'avvio dell'applicazione.",
"syncModelsDesc": "Se i tuoi modelli non sono sincronizzati con il back-end, puoi aggiornarli utilizzando questa opzione. Questo è generalmente utile nei casi in cui aggiorni manualmente il tuo file models.yaml o aggiungi modelli alla cartella principale di InvokeAI dopo l'avvio dell'applicazione.",
"loraModels": "LoRA",
"oliveModels": "Olive",
"onnxModels": "ONNX",
"noModels": "Nessun modello trovato",
"predictionType": "Tipo di previsione",
"predictionType": "Tipo di previsione (per modelli Stable Diffusion 2.x ed alcuni modelli Stable Diffusion 1.x)",
"quickAdd": "Aggiunta rapida",
"simpleModelDesc": "Fornire un percorso a un modello diffusori locale, un modello checkpoint/safetensor locale, un ID repository HuggingFace o un URL del modello checkpoint/diffusori.",
"advanced": "Avanzate",
@ -527,34 +521,7 @@
"vaePrecision": "Precisione VAE",
"noModelSelected": "Nessun modello selezionato",
"conversionNotSupported": "Conversione non supportata",
"configFile": "File di configurazione",
"modelName": "Nome del modello",
"modelSettings": "Impostazioni del modello",
"advancedImportInfo": "La scheda opzioni avanzate consente la configurazione manuale delle impostazioni del modello principale. Utilizza questa scheda solo se sei sicuro di conoscere il tipo di modello e la configurazione corretti per il modello selezionato.",
"addAll": "Aggiungi tutto",
"addModels": "Aggiungi modelli",
"cancel": "Annulla",
"edit": "Modifica",
"imageEncoderModelId": "ID modello codificatore di immagini",
"importQueue": "Coda di importazione",
"modelMetadata": "Metadati del modello",
"path": "Percorso",
"prune": "Elimina",
"pruneTooltip": "Elimina dalla coda le importazioni completate",
"removeFromQueue": "Rimuovi dalla coda",
"repoVariant": "Variante del repository",
"scan": "Scansiona",
"scanFolder": "Scansione cartella",
"scanResults": "Risultati della scansione",
"source": "Sorgente",
"upcastAttention": "Eleva l'attenzione",
"ztsnrTraining": "Addestramento ZTSNR",
"typePhraseHere": "Digita la frase qui",
"defaultSettingsSaved": "Impostazioni predefinite salvate",
"defaultSettings": "Impostazioni predefinite",
"metadata": "Metadati",
"useDefaultSettings": "Usa le impostazioni predefinite",
"triggerPhrases": "Frasi trigger"
"configFile": "File di configurazione"
},
"parameters": {
"images": "Immagini",
@ -636,8 +603,8 @@
"clipSkip": "CLIP Skip",
"aspectRatio": "Proporzioni",
"maskAdjustmentsHeader": "Regolazioni della maschera",
"maskBlur": "Sfocatura maschera",
"maskBlurMethod": "Metodo sfocatura maschera",
"maskBlur": "Sfocatura",
"maskBlurMethod": "Metodo di sfocatura",
"seamLowThreshold": "Basso",
"seamHighThreshold": "Alto",
"coherencePassHeader": "Passaggio di coerenza",
@ -694,8 +661,7 @@
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (potrebbe essere troppo grande)",
"boxBlur": "Box",
"gaussianBlur": "Gaussian",
"remixImage": "Remixa l'immagine",
"coherenceEdgeSize": "Dimensione bordo"
"remixImage": "Remixa l'immagine"
},
"settings": {
"models": "Modelli",
@ -778,8 +744,8 @@
"canceled": "Elaborazione annullata",
"problemCopyingImageLink": "Impossibile copiare il collegamento dell'immagine",
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
"parameterSet": "{{parameter}} impostato",
"parameterNotSet": "{{parameter}} non impostato",
"parameterSet": "Parametro impostato",
"parameterNotSet": "Parametro non impostato",
"nodesLoadedFailed": "Impossibile caricare i nodi",
"nodesSaved": "Nodi salvati",
"nodesLoaded": "Nodi caricati",
@ -832,10 +798,7 @@
"problemRetrievingWorkflow": "Problema nel recupero del flusso di lavoro",
"resetInitialImage": "Reimposta l'immagine iniziale",
"uploadInitialImage": "Carica l'immagine iniziale",
"problemDownloadingImage": "Impossibile scaricare l'immagine",
"prunedQueue": "Coda ripulita",
"modelImportCanceled": "Importazione del modello annullata",
"modelImportRemoved": "Importazione del modello rimossa"
"problemDownloadingImage": "Impossibile scaricare l'immagine"
},
"tooltip": {
"feature": {
@ -913,10 +876,7 @@
"antialiasing": "Anti aliasing",
"showResultsOn": "Mostra i risultati (attivato)",
"showResultsOff": "Mostra i risultati (disattivato)",
"saveMask": "Salva $t(unifiedCanvas.mask)",
"coherenceModeGaussianBlur": "Sfocatura Gaussiana",
"coherenceModeBoxBlur": "Sfocatura Box",
"coherenceModeStaged": "Maschera espansa"
"saveMask": "Salva $t(unifiedCanvas.mask)"
},
"accessibility": {
"modelSelect": "Seleziona modello",
@ -1385,8 +1345,7 @@
"allLoRAsAdded": "Tutti i LoRA aggiunti",
"defaultVAE": "VAE predefinito",
"incompatibleBaseModel": "Modello base incompatibile",
"loraAlreadyAdded": "LoRA già aggiunto",
"concepts": "Concetti"
"loraAlreadyAdded": "LoRA già aggiunto"
},
"invocationCache": {
"disable": "Disabilita",
@ -1739,25 +1698,6 @@
"paragraphs": [
"Valuta le generazioni in modo che siano più simili alle immagini con un punteggio estetico elevato, in base ai dati di addestramento."
]
},
"compositingCoherenceMinDenoise": {
"heading": "Livello minimo di riduzione del rumore",
"paragraphs": [
"Intensità minima di riduzione rumore per la modalità di Coerenza",
"L'intensità minima di riduzione del rumore per la regione di coerenza durante l'inpainting o l'outpainting"
]
},
"compositingMaskBlur": {
"paragraphs": [
"Il raggio di sfocatura della maschera."
],
"heading": "Sfocatura maschera"
},
"compositingCoherenceEdgeSize": {
"heading": "Dimensione del bordo",
"paragraphs": [
"La dimensione del bordo del passaggio di coerenza."
]
}
},
"sdxl": {
@ -1806,12 +1746,7 @@
"scheduler": "Campionatore",
"recallParameters": "Richiama i parametri",
"noRecallParameters": "Nessun parametro da richiamare trovato",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"allPrompts": "Tutti i prompt",
"imageDimensions": "Dimensioni dell'immagine",
"parameterSet": "Parametro {{parameter}} impostato",
"parsingFailed": "Analisi non riuscita",
"recallParameter": "Richiama {{label}}"
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
},
"hrf": {
"enableHrf": "Abilita Correzione Alta Risoluzione",
@ -1883,11 +1818,5 @@
"image": {
"title": "Immagine"
}
},
"prompt": {
"compatibleEmbeddings": "Incorporamenti compatibili",
"addPromptTrigger": "Aggiungi parola chiave nel prompt",
"noPromptTriggers": "Nessuna parola chiave disponibile",
"noMatchingTriggers": "Nessuna parola chiave corrispondente"
}
}

@ -52,7 +52,7 @@
"accept": "Принять",
"postprocessing": "Постобработка",
"txt2img": "Текст в изображение (txt2img)",
"linear": "Линейный вид",
"linear": "Линейная обработка",
"dontAskMeAgain": "Больше не спрашивать",
"areYouSure": "Вы уверены?",
"random": "Случайное",
@ -117,8 +117,7 @@
"toResolve": "Чтоб решить",
"copy": "Копировать",
"localSystem": "Локальная система",
"aboutDesc": "Используя Invoke для работы? Проверьте это:",
"add": "Добавить"
"aboutDesc": "Используя Invoke для работы? Проверьте это:"
},
"gallery": {
"generations": "Генерации",
@ -156,12 +155,7 @@
"noImageSelected": "Изображение не выбрано",
"setCurrentImage": "Установить как текущее изображение",
"starImage": "Добавить в избранное",
"dropToUpload": "$t(gallery.drop) чтоб загрузить",
"bulkDownloadFailed": "Загрузка не удалась",
"bulkDownloadStarting": "Начало загрузки",
"bulkDownloadRequested": "Подготовка к скачиванию",
"bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.",
"bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания"
"dropToUpload": "$t(gallery.drop) чтоб загрузить"
},
"hotkeys": {
"keyboardShortcuts": "Горячие клавиши",
@ -510,7 +504,7 @@
"settings": "Настройки",
"selectModel": "Выберите модель",
"syncModels": "Синхронизация моделей",
"syncModelsDesc": "Если ваши модели не синхронизированы с серверной частью, вы можете обновить их с помощью этой опции. Обычно это удобно в тех случаях, когда вы добавляете модели в корневую папку InvokeAI или каталог автоимпорта после загрузки приложения.",
"syncModelsDesc": "Если ваши модели не синхронизированы с серверной частью, вы можете обновить их, используя эту опцию. Обычно это удобно в тех случаях, когда вы вручную обновляете свой файл \"models.yaml\" или добавляете модели в корневую папку InvokeAI после загрузки приложения.",
"modelUpdateFailed": "Не удалось обновить модель",
"modelConversionFailed": "Не удалось сконвертировать модель",
"modelsMergeFailed": "Не удалось выполнить слияние моделей",
@ -519,7 +513,7 @@
"oliveModels": "Модели Olives",
"conversionNotSupported": "Преобразование не поддерживается",
"noModels": "Нет моделей",
"predictionType": "Тип прогноза",
"predictionType": "Тип прогноза (для моделей Stable Diffusion 2.x и периодических моделей Stable Diffusion 1.x)",
"quickAdd": "Быстрое добавление",
"simpleModelDesc": "Укажите путь к локальной модели Diffusers , локальной модели checkpoint / safetensors, идентификатор репозитория HuggingFace или URL-адрес модели контрольной checkpoint / diffusers.",
"advanced": "Продвинутый",
@ -530,33 +524,7 @@
"customConfigFileLocation": "Расположение пользовательского файла конфигурации",
"vaePrecision": "Точность VAE",
"noModelSelected": "Модель не выбрана",
"configFile": "Файл конфигурации",
"addAll": "Добавить всё",
"addModels": "Добавить модели",
"cancel": "Отмена",
"defaultSettings": "Стандартные настройки",
"importQueue": "Импортировать очередь",
"metadata": "Метаданные",
"imageEncoderModelId": "ID модели-энкодера изображений",
"typePhraseHere": "Введите фразы здесь",
"advancedImportInfo": "Вкладка «Дополнительно» позволяет вручную настроить основные параметры модели. Используйте эту вкладку только в том случае, если вы уверены, что знаете правильный тип модели и конфигурацию выбранной модели.",
"defaultSettingsSaved": "Стандартные настройки сохранены",
"edit": "Редактировать",
"path": "Путь",
"prune": "Удалить",
"pruneTooltip": "Удалить готовые импорты из очереди",
"removeFromQueue": "Удалить из очереди",
"repoVariant": "Вариант репозитория",
"scan": "Сканировать",
"scanFolder": "Сканировать папку",
"scanResults": "Результаты сканирования",
"source": "Источник",
"triggerPhrases": "Триггерные фразы",
"useDefaultSettings": "Использовать стандартные настройки",
"modelMetadata": "Метаданные модели",
"modelName": "Название модели",
"modelSettings": "Настройки модели",
"upcastAttention": "Внимание"
"configFile": "Файл конфигурации"
},
"parameters": {
"images": "Изображения",
@ -623,7 +591,7 @@
"hSymmetryStep": "Шаг гор. симметрии",
"hidePreview": "Скрыть предпросмотр",
"imageToImage": "Изображение в изображение",
"denoisingStrength": "Сила зашумления",
"denoisingStrength": "Сила шумоподавления",
"copyImage": "Скопировать изображение",
"showPreview": "Показать предпросмотр",
"noiseSettings": "Шум",
@ -638,8 +606,8 @@
"clipSkip": "CLIP Пропуск",
"aspectRatio": "Соотношение",
"maskAdjustmentsHeader": "Настройка маски",
"maskBlur": "Размытие маски",
"maskBlurMethod": "Метод размытия маски",
"maskBlur": "Размытие",
"maskBlurMethod": "Метод размытия",
"seamLowThreshold": "Низкий",
"seamHighThreshold": "Высокий",
"coherencePassHeader": "Порог Coherence",
@ -698,9 +666,7 @@
"lockAspectRatio": "Заблокировать соотношение",
"boxBlur": "Размытие прямоугольника",
"gaussianBlur": "Размытие по Гауссу",
"remixImage": "Ремикс изображения",
"coherenceMinDenoise": "Мин. шумоподавление",
"coherenceEdgeSize": "Размер края"
"remixImage": "Ремикс изображения"
},
"settings": {
"models": "Модели",
@ -783,8 +749,8 @@
"canceled": "Обработка отменена",
"problemCopyingImageLink": "Не удалось скопировать ссылку на изображение",
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
"parameterNotSet": "Параметр {{parameter}} не задан",
"parameterSet": "Параметр {{parameter}} задан",
"parameterNotSet": "Параметр не задан",
"parameterSet": "Параметр задан",
"nodesLoaded": "Узлы загружены",
"problemCopyingImage": "Не удается скопировать изображение",
"nodesLoadedFailed": "Не удалось загрузить Узлы",
@ -837,10 +803,7 @@
"problemImportingMask": "Проблема с импортом маски",
"problemDownloadingImage": "Не удается скачать изображение",
"uploadInitialImage": "Загрузить начальное изображение",
"resetInitialImage": "Сбросить начальное изображение",
"prunedQueue": "Урезанная очередь",
"modelImportCanceled": "Импорт модели отменен",
"modelImportRemoved": "Импорт модели удален"
"resetInitialImage": "Сбросить начальное изображение"
},
"tooltip": {
"feature": {
@ -1182,11 +1145,7 @@
"reorderLinearView": "Изменить порядок линейного просмотра",
"viewMode": "Использовать в линейном представлении",
"editMode": "Открыть в редакторе узлов",
"resetToDefaultValue": "Сбросить к стандартному значкнию",
"latentsField": "Латенты",
"latentsCollectionDescription": "Латенты могут передаваться между узлами.",
"latentsPolymorphicDescription": "Латенты могут передаваться между узлами.",
"latentsFieldDescription": "Латенты могут передаваться между узлами."
"resetToDefaultValue": "Сбросить к стандартному значкнию"
},
"controlnet": {
"amult": "a_mult",
@ -1335,8 +1294,7 @@
},
"paramScheduler": {
"paragraphs": [
"Планировщик, используемый в процессе генерации.",
"Каждый планировщик определяет, как итеративно добавлять шум к изображению или как обновлять образец на основе выходных данных модели."
"Планировщик определяет, как итеративно добавлять шум к изображению или как обновлять образец на основе выходных данных модели."
],
"heading": "Планировщик"
},
@ -1389,7 +1347,7 @@
"compositingCoherenceMode": {
"heading": "Режим",
"paragraphs": [
"Метод, используемый для создания связного изображения с вновь созданной замаскированной областью."
"Режим прохождения когерентности."
]
},
"paramSeed": {
@ -1407,7 +1365,7 @@
},
"controlNetBeginEnd": {
"paragraphs": [
"Часть процесса шумоподавления, к которой будет применен адаптер контроля.",
"На каких этапах процесса шумоподавления будет применена ControlNet.",
"ControlNet, применяемые в начале процесса, направляют композицию, а ControlNet, применяемые в конце, направляют детали."
],
"heading": "Процент начала/конца шага"
@ -1423,8 +1381,8 @@
},
"clipSkip": {
"paragraphs": [
"Сколько слоев модели CLIP пропустить.",
"Некоторые модели лучше подходят для использования с CLIP Skip."
"Выберите, сколько слоев модели CLIP нужно пропустить.",
"Некоторые модели работают лучше с определенными настройками пропуска CLIP."
],
"heading": "CLIP пропуск"
},
@ -1521,25 +1479,6 @@
"paragraphs": [
"Более высокий вес LoRA приведет к большему влиянию на конечное изображение."
]
},
"compositingMaskBlur": {
"heading": "Размытие маски",
"paragraphs": [
"Радиус размытия маски."
]
},
"compositingCoherenceMinDenoise": {
"heading": "Минимальное шумоподавление",
"paragraphs": [
"Минимальный уровень шумоподавления для режима Coherence",
"Минимальный уровень шумоподавления для области когерентности при перерисовывании или дорисовке"
]
},
"compositingCoherenceEdgeSize": {
"heading": "Размер края",
"paragraphs": [
"Размер края прохода когерентности."
]
}
},
"metadata": {
@ -1570,12 +1509,7 @@
"steps": "Шаги",
"scheduler": "Планировщик",
"noRecallParameters": "Параметры для вызова не найдены",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"parameterSet": "Параметр {{parameter}} установлен",
"parsingFailed": "Не удалось выполнить синтаксический анализ",
"recallParameter": "Отозвать {{label}}",
"allPrompts": "Все запросы",
"imageDimensions": "Размеры изображения"
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
},
"queue": {
"status": "Статус",
@ -1654,11 +1588,10 @@
"denoisingStrength": "Шумоподавление",
"refinermodel": "Модель перерисовщик",
"posAestheticScore": "Положительная эстетическая оценка",
"concatPromptStyle": "Связывание запроса и стиля",
"concatPromptStyle": "Объединение запроса и стиля",
"loading": "Загрузка...",
"steps": "Шаги",
"posStylePrompt": "Запрос стиля",
"freePromptStyle": "Ручной запрос стиля"
"posStylePrompt": "Запрос стиля"
},
"invocationCache": {
"useCache": "Использовать кэш",
@ -1745,8 +1678,7 @@
"allLoRAsAdded": "Все LoRA добавлены",
"defaultVAE": "Стандартное VAE",
"incompatibleBaseModel": "Несовместимая базовая модель",
"loraAlreadyAdded": "LoRA уже добавлена",
"concepts": "Концепты"
"loraAlreadyAdded": "LoRA уже добавлена"
},
"app": {
"storeNotInitialized": "Магазин не инициализирован"
@ -1764,7 +1696,7 @@
},
"generation": {
"title": "Генерация",
"conceptsTab": "LoRA",
"conceptsTab": "Концепты",
"modelTab": "Модель"
},
"advanced": {

@ -5,55 +5,18 @@ import openapiTS from 'openapi-typescript';
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
const OUTPUT_FILE = 'src/services/api/schema.ts';

async function generateTypes(schema) {
process.stdout.write(`Generating types ${OUTPUT_FILE}...`);
const types = await openapiTS(schema, {
async function main() {
process.stdout.write(`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...`);
const types = await openapiTS(OPENAPI_URL, {
exportType: true,
transform: (schemaObject) => {
if ('format' in schemaObject && schemaObject.format === 'binary') {
return schemaObject.nullable ? 'Blob | null' : 'Blob';
}
if (schemaObject.title === 'MetadataField') {
// This is `Record<string, never>` by default, but it actually accepts a dict of any valid JSON value.
return 'Record<string, unknown>';
}
},
});
fs.writeFileSync(OUTPUT_FILE, types);
process.stdout.write(`\nOK!\r\n`);
}

async function main() {
const encoding = 'utf-8';

if (process.stdin.isTTY) {
// Handle generating types with an arg (e.g. URL or path to file)
if (process.argv.length > 3) {
console.error('Usage: typegen.js <openapi.json>');
process.exit(1);
}
if (process.argv[2]) {
const schema = new Buffer.from(process.argv[2], encoding);
generateTypes(schema);
} else {
generateTypes(OPENAPI_URL);
}
} else {
// Handle generating types from stdin
let schema = '';
process.stdin.setEncoding(encoding);

process.stdin.on('readable', function () {
const chunk = process.stdin.read();
if (chunk !== null) {
schema += chunk;
}
});

process.stdin.on('end', function () {
generateTypes(JSON.parse(schema));
});
}
}

main();

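For context, the reworked typegen script above accepts its schema three ways: JSON piped on stdin, a URL or file path passed as the single argument (per its own usage string, `Usage: typegen.js <openapi.json>`), or the hard-coded OPENAPI_URL as the TTY fallback. A minimal invocation sketch, assuming the script is run directly with node (the wrapper command is an assumption, not taken from this diff):

// Hypothetical invocations of typegen.js (wrapper command assumed):
//   node typegen.js                          -> TTY, no arg: falls back to OPENAPI_URL
//   node typegen.js ./openapi.json           -> arg: treated as a schema source
//   cat openapi.json | node typegen.js       -> piped stdin: buffered, JSON.parse'd, then typed
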
@ -38,7 +38,7 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
type: 'image/png',
}),
image_category: 'control',
is_intermediate: true,
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: false,
postUploadAction: {

@ -48,7 +48,7 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: true,
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: false,
postUploadAction: {

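Both hunks above make the same one-line change: the image the canvas sends to the control adapter is now uploaded with is_intermediate: false. Condensed, the payload these listeners build looks like this (field values copied from the hunks; the surrounding upload call and file construction are elided, and the persistence note is an inference, not stated in the diff):

// Sketch of the upload arguments after this change (shape taken from the hunks above).
const uploadArg = {
  image_category: 'control', // 'mask' in the second listener
  is_intermediate: false, // was true; presumably so the image persists rather than being treated as a temporary result
  board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
  crop_visible: false,
};
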
@ -101,7 +101,7 @@ export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartLis
).unwrap();
}

const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
const graph = buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);

log.debug({ graph: parseify(graph) }, `Canvas graph built`);

@ -20,15 +20,15 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)

if (model && model.base === 'sdxl') {
if (action.payload.tabName === 'txt2img') {
graph = await buildLinearSDXLTextToImageGraph(state);
graph = buildLinearSDXLTextToImageGraph(state);
} else {
graph = await buildLinearSDXLImageToImageGraph(state);
graph = buildLinearSDXLImageToImageGraph(state);
}
} else {
if (action.payload.tabName === 'txt2img') {
graph = await buildLinearTextToImageGraph(state);
graph = buildLinearTextToImageGraph(state);
} else {
graph = await buildLinearImageToImageGraph(state);
graph = buildLinearImageToImageGraph(state);
}
}

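The two enqueue listeners above drop await because the graph builders are now synchronous. Reduced to its decision structure, the linear dispatch reads roughly as follows (a sketch over the names in the hunk, not the full listener):

// Sketch: synchronous builder selection by model base and active tab.
const isSDXL = Boolean(model && model.base === 'sdxl');
const isTxt2Img = action.payload.tabName === 'txt2img';
const graph = isSDXL
  ? (isTxt2Img ? buildLinearSDXLTextToImageGraph(state) : buildLinearSDXLImageToImageGraph(state))
  : (isTxt2Img ? buildLinearTextToImageGraph(state) : buildLinearImageToImageGraph(state));
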
@ -20,7 +20,7 @@ const sx: ChakraProps['sx'] = {
'.react-colorful__hue-pointer': colorPickerPointerStyles,
'.react-colorful__saturation-pointer': colorPickerPointerStyles,
'.react-colorful__alpha-pointer': colorPickerPointerStyles,
gap: 5,
gap: 2,
flexDir: 'column',
};

@ -39,8 +39,8 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
<Flex sx={sx}>
<RgbaColorPicker color={color} onChange={onChange} style={colorPickerStyles} {...rest} />
{withNumberInput && (
<Flex gap={5}>
<FormControl gap={0}>
<Flex>
<FormControl>
<FormLabel>{t('common.red')}</FormLabel>
<CompositeNumberInput
value={color.r}
@ -52,7 +52,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
defaultValue={90}
/>
</FormControl>
<FormControl gap={0}>
<FormControl>
<FormLabel>{t('common.green')}</FormLabel>
<CompositeNumberInput
value={color.g}
@ -64,7 +64,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
defaultValue={90}
/>
</FormControl>
<FormControl gap={0}>
<FormControl>
<FormLabel>{t('common.blue')}</FormLabel>
<CompositeNumberInput
value={color.b}
@ -76,7 +76,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
defaultValue={255}
/>
</FormControl>
<FormControl gap={0}>
<FormControl>
<FormLabel>{t('common.alpha')}</FormLabel>
<CompositeNumberInput
value={color.a}

@ -29,7 +29,7 @@ import { Layer, Stage } from 'react-konva';
import IAICanvasBoundingBoxOverlay from './IAICanvasBoundingBoxOverlay';
import IAICanvasGrid from './IAICanvasGrid';
import IAICanvasIntermediateImage from './IAICanvasIntermediateImage';
import IAICanvasMaskCompositor from './IAICanvasMaskCompositor';
import IAICanvasMaskCompositer from './IAICanvasMaskCompositer';
import IAICanvasMaskLines from './IAICanvasMaskLines';
import IAICanvasObjectRenderer from './IAICanvasObjectRenderer';
import IAICanvasStagingArea from './IAICanvasStagingArea';
@ -176,7 +176,7 @@ const IAICanvas = () => {
</Layer>
<Layer id="mask" visible={isMaskEnabled && !isStaging} listening={false}>
<IAICanvasMaskLines visible={true} listening={false} />
<IAICanvasMaskCompositor listening={false} />
<IAICanvasMaskCompositer listening={false} />
</Layer>
<Layer listening={false}>
<IAICanvasBoundingBoxOverlay />

@ -16,9 +16,9 @@ const canvasMaskCompositerSelector = createMemoizedSelector(selectCanvasSlice, (
};
});

type IAICanvasMaskCompositorProps = RectConfig;
type IAICanvasMaskCompositerProps = RectConfig;

const IAICanvasMaskCompositor = (props: IAICanvasMaskCompositorProps) => {
const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
const { ...rest } = props;

const { stageCoordinates, stageDimensions } = useAppSelector(canvasMaskCompositerSelector);
@ -89,4 +89,4 @@ const IAICanvasMaskCompositor = (props: IAICanvasMaskCompositorProps) => {
);
};

export default memo(IAICanvasMaskCompositor);
export default memo(IAICanvasMaskCompositer);

@ -5,7 +5,6 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { stagingAreaImageSaved } from 'features/canvas/store/actions';
import {
commitStagingAreaImage,
discardStagedImage,
discardStagedImages,
nextStagingAreaImage,
prevStagingAreaImage,
@ -23,7 +22,6 @@ import {
PiEyeBold,
PiEyeSlashBold,
PiFloppyDiskBold,
PiTrashSimpleBold,
PiXBold,
} from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
@ -46,40 +44,6 @@ const selector = createMemoizedSelector(selectCanvasSlice, (canvas) => {
};
});

const ClearStagingIntermediatesIconButton = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();

const handleDiscardStagingArea = useCallback(() => {
dispatch(discardStagedImages());
}, [dispatch]);

const handleDiscardStagingImage = useCallback(() => {
dispatch(discardStagedImage());
}, [dispatch]);

return (
<>
<IconButton
tooltip={`${t('unifiedCanvas.discardCurrent')}`}
aria-label={t('unifiedCanvas.discardCurrent')}
icon={<PiXBold />}
onClick={handleDiscardStagingImage}
colorScheme="invokeBlue"
fontSize={16}
/>
<IconButton
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
aria-label={t('unifiedCanvas.discardAll')}
icon={<PiTrashSimpleBold />}
onClick={handleDiscardStagingArea}
colorScheme="error"
fontSize={16}
/>
</>
);
};

const IAICanvasStagingAreaToolbar = () => {
const dispatch = useAppDispatch();
const { currentStagingAreaImage, shouldShowStagingImage, currentIndex, total } = useAppSelector(selector);
@ -221,7 +185,14 @@ const IAICanvasStagingAreaToolbar = () => {
onClick={handleSaveToGallery}
colorScheme="invokeBlue"
/>
<ClearStagingIntermediatesIconButton />
<IconButton
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
aria-label={t('unifiedCanvas.discardAll')}
icon={<PiXBold />}
onClick={handleDiscardStagingArea}
colorScheme="error"
fontSize={20}
/>
</ButtonGroup>
</Flex>
);

@ -18,7 +18,6 @@ import {
setShouldAutoSave,
setShouldCropToBoundingBoxOnSave,
setShouldDarkenOutsideBoundingBox,
setShouldInvertBrushSizeScrollDirection,
setShouldRestrictStrokesToBox,
setShouldShowCanvasDebugInfo,
setShouldShowGrid,
@ -41,7 +40,6 @@ const IAICanvasSettingsButtonPopover = () => {
const shouldAutoSave = useAppSelector((s) => s.canvas.shouldAutoSave);
const shouldCropToBoundingBoxOnSave = useAppSelector((s) => s.canvas.shouldCropToBoundingBoxOnSave);
const shouldDarkenOutsideBoundingBox = useAppSelector((s) => s.canvas.shouldDarkenOutsideBoundingBox);
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
const shouldShowCanvasDebugInfo = useAppSelector((s) => s.canvas.shouldShowCanvasDebugInfo);
const shouldShowGrid = useAppSelector((s) => s.canvas.shouldShowGrid);
const shouldShowIntermediates = useAppSelector((s) => s.canvas.shouldShowIntermediates);
@ -78,10 +76,6 @@ const IAICanvasSettingsButtonPopover = () => {
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked)),
[dispatch]
);
const handleChangeShouldInvertBrushSizeScrollDirection = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldInvertBrushSizeScrollDirection(e.target.checked)),
[dispatch]
);
const handleChangeShouldAutoSave = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldAutoSave(e.target.checked)),
[dispatch]
@ -150,13 +144,6 @@ const IAICanvasSettingsButtonPopover = () => {
<FormLabel>{t('unifiedCanvas.limitStrokesToBox')}</FormLabel>
<Checkbox isChecked={shouldRestrictStrokesToBox} onChange={handleChangeShouldRestrictStrokesToBox} />
</FormControl>
<FormControl>
<FormLabel>{t('unifiedCanvas.invertBrushSizeScrollDirection')}</FormLabel>
<Checkbox
isChecked={shouldInvertBrushSizeScrollDirection}
onChange={handleChangeShouldInvertBrushSizeScrollDirection}
/>
</FormControl>
<FormControl>
<FormLabel>{t('unifiedCanvas.showCanvasDebugInfo')}</FormLabel>
<Checkbox isChecked={shouldShowCanvasDebugInfo} onChange={handleChangeShouldShowCanvasDebugInfo} />

@ -15,7 +15,6 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
const stageScale = useAppSelector((s) => s.canvas.stageScale);
const isMoveStageKeyHeld = useStore($isMoveStageKeyHeld);
const brushSize = useAppSelector((s) => s.canvas.brushSize);
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);

return useCallback(
(e: KonvaEventObject<WheelEvent>) => {
@ -29,16 +28,10 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
// checking for ctrl key is pressed or not,
// so that brush size can be controlled using ctrl + scroll up/down

// Invert the delta if the property is set to true
let delta = e.evt.deltaY;
if (shouldInvertBrushSizeScrollDirection) {
delta = -delta;
}

if ($ctrl.get() || $meta.get()) {
// This equation was derived by fitting a curve to the desired brush sizes and deltas
// see https://github.com/invoke-ai/InvokeAI/pull/5542#issuecomment-1915847565
const targetDelta = Math.sign(delta) * 0.7363 * Math.pow(1.0394, brushSize);
const targetDelta = Math.sign(e.evt.deltaY) * 0.7363 * Math.pow(1.0394, brushSize);
// This needs to be clamped to prevent the delta from getting too large
const finalDelta = clamp(targetDelta, -20, 20);
// The new brush size is also clamped to prevent it from getting too large or small
@ -74,7 +67,7 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
dispatch(setStageCoordinates(newCoordinates));
}
},
[stageRef, isMoveStageKeyHeld, brushSize, dispatch, stageScale, shouldInvertBrushSizeScrollDirection]
[stageRef, isMoveStageKeyHeld, stageScale, dispatch, brushSize]
);
};

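The exponential curve above is what makes ctrl+scroll feel proportional: a wheel tick nudges a small brush by about a pixel but moves a large brush much faster, with the clamp capping runaway steps. A quick worked sketch (rounded values, illustrative only):

// Evaluating the fitted curve from this hunk at a few brush sizes (sign factor omitted).
const step = (brushSize: number) => 0.7363 * Math.pow(1.0394, brushSize);
// step(10)  ~= 1.08  -> a small brush grows about 1px per wheel tick
// step(50)  ~= 5.08  -> a mid-size brush resizes faster
// step(100) ~= 35.1  -> clamp(targetDelta, -20, 20) caps this at 20
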
@ -65,7 +65,6 @@ const initialCanvasState: CanvasState = {
shouldAutoSave: false,
shouldCropToBoundingBoxOnSave: false,
shouldDarkenOutsideBoundingBox: false,
shouldInvertBrushSizeScrollDirection: false,
shouldLockBoundingBox: false,
shouldPreserveMaskedArea: false,
shouldRestrictStrokesToBox: true,
@ -221,9 +220,6 @@ export const canvasSlice = createSlice({
setShouldDarkenOutsideBoundingBox: (state, action: PayloadAction<boolean>) => {
state.shouldDarkenOutsideBoundingBox = action.payload;
},
setShouldInvertBrushSizeScrollDirection: (state, action: PayloadAction<boolean>) => {
state.shouldInvertBrushSizeScrollDirection = action.payload;
},
clearCanvasHistory: (state) => {
state.pastLayerStates = [];
state.futureLayerStates = [];
@ -292,31 +288,6 @@ export const canvasSlice = createSlice({
state.shouldShowStagingImage = true;
state.batchIds = [];
},
discardStagedImage: (state) => {
const { images, selectedImageIndex } = state.layerState.stagingArea;
state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}

if (!images.length) {
return;
}

images.splice(selectedImageIndex, 1);

if (selectedImageIndex >= images.length) {
state.layerState.stagingArea.selectedImageIndex = images.length - 1;
}

if (!images.length) {
state.shouldShowStagingImage = false;
state.shouldShowStagingOutline = false;
}

state.futureLayerStates = [];
},
addFillRect: (state) => {
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } = state;

@ -684,7 +655,6 @@
commitColorPickerColor,
commitStagingAreaImage,
discardStagedImages,
discardStagedImage,
nextStagingAreaImage,
prevStagingAreaImage,
redo,
@ -704,7 +674,6 @@
setShouldAutoSave,
setShouldCropToBoundingBoxOnSave,
setShouldDarkenOutsideBoundingBox,
setShouldInvertBrushSizeScrollDirection,
setShouldPreserveMaskedArea,
setShouldShowBoundingBox,
setShouldShowCanvasDebugInfo,

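The discardStagedImage reducer in the hunk above interleaves three concerns: undo history (pastLayerStates capped at MAX_HISTORY), removing the selected staged image, and keeping selectedImageIndex valid afterwards. The index bookkeeping in isolation (a minimal sketch, not the slice itself):

// Sketch of the index-clamping logic from the reducer above.
function discardSelected<T>(images: T[], selectedImageIndex: number): number {
  images.splice(selectedImageIndex, 1); // remove the currently selected staged image
  return selectedImageIndex >= images.length
    ? images.length - 1 // clamp to the new last image (-1 when the array is now empty)
    : selectedImageIndex; // earlier selections keep their index
}
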
@ -120,7 +120,6 @@ export interface CanvasState {
shouldAutoSave: boolean;
shouldCropToBoundingBoxOnSave: boolean;
shouldDarkenOutsideBoundingBox: boolean;
shouldInvertBrushSizeScrollDirection: boolean;
shouldLockBoundingBox: boolean;
shouldPreserveMaskedArea: boolean;
shouldRestrictStrokesToBox: boolean;

@ -6,7 +6,7 @@ const AutoAddIcon = () => {
const { t } = useTranslation();
return (
<Flex position="absolute" insetInlineEnd={0} top={0} p={1}>
<Badge variant="solid" bg="invokeBlue.400">
<Badge variant="solid" bg="invokeBlue.500">
{t('common.auto')}
</Badge>
</Flex>

@ -173,8 +173,8 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
w="full"
maxW="full"
borderBottomRadius="base"
bg={isSelected ? 'invokeBlue.400' : 'base.600'}
color={isSelected ? 'base.800' : 'base.100'}
bg={isSelected ? 'invokeBlue.500' : 'base.600'}
color={isSelected ? 'base.50' : 'base.100'}
lineHeight="short"
fontSize="xs"
>
@ -193,7 +193,6 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
overflow="hidden"
textOverflow="ellipsis"
noOfLines={1}
color="inherit"
/>
<EditableInput sx={editableInputStyles} />
</Editable>

@ -109,8 +109,8 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
w="full"
maxW="full"
borderBottomRadius="base"
bg={isSelected ? 'invokeBlue.400' : 'base.600'}
color={isSelected ? 'base.800' : 'base.100'}
bg={isSelected ? 'invokeBlue.500' : 'base.600'}
color={isSelected ? 'base.50' : 'base.100'}
lineHeight="short"
fontSize="xs"
fontWeight={isSelected ? 'bold' : 'normal'}

@ -15,7 +15,7 @@ export const MetadataItemView = memo(
return (
<Flex gap={2}>
{onRecall && <RecallButton label={label} onClick={onRecall} isDisabled={isDisabled} />}
<Flex direction={direction} fontSize="sm">
<Flex direction={direction}>
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
{label}:
</Text>

@ -1,27 +0,0 @@
|
||||
import { Box, Image } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
|
||||
type Props = {
|
||||
image_url?: string;
|
||||
};
|
||||
|
||||
const ModelImage = ({ image_url }: Props) => {
|
||||
if (!image_url) {
|
||||
return <Box height="50px" minWidth="50px" />;
|
||||
}
|
||||
|
||||
return (
|
||||
<Image
|
||||
src={image_url}
|
||||
objectFit="cover"
|
||||
objectPosition="50% 50%"
|
||||
height="50px"
|
||||
width="50px"
|
||||
minHeight="50px"
|
||||
minWidth="50px"
|
||||
borderRadius="base"
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default typedMemo(ModelImage);
|
@ -22,8 +22,6 @@ import { PiTrashSimpleBold } from 'react-icons/pi';
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';

import ModelImage from './ModelImage';

type ModelListItemProps = {
model: AnyModelConfig;
};
@ -75,7 +73,6 @@ const ModelListItem = (props: ModelListItemProps) => {

return (
<Flex gap={2} alignItems="center" w="full">
<ModelImage image_url={model.cover_image || ''} />
<Flex
as={Button}
isChecked={isSelected}

@ -1,134 +0,0 @@
|
||||
import { Box, Button, IconButton, Image } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
|
||||
import { useDeleteModelImageMutation, useUpdateModelImageMutation } from 'services/api/endpoints/models';
|
||||
|
||||
type Props = {
|
||||
model_key: string | null;
|
||||
model_image?: string | null;
|
||||
};
|
||||
|
||||
const ModelImageUpload = ({ model_key, model_image }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const [image, setImage] = useState<string | null>(model_image || null);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [updateModelImage] = useUpdateModelImageMutation();
|
||||
const [deleteModelImage] = useDeleteModelImageMutation();
|
||||
|
||||
const onDropAccepted = useCallback(
|
||||
(files: File[]) => {
|
||||
const file = files[0];
|
||||
|
||||
if (!file || !model_key) {
|
||||
return;
|
||||
}
|
||||
|
||||
updateModelImage({ key: model_key, image: file })
|
||||
.unwrap()
|
||||
.then(() => {
|
||||
setImage(URL.createObjectURL(file));
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelImageUpdated'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelImageUpdateFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
});
|
||||
},
|
||||
[dispatch, model_key, t, updateModelImage]
|
||||
);
|
||||
|
||||
const handleResetImage = useCallback(() => {
|
||||
if (!model_key) {
|
||||
return;
|
||||
}
|
||||
setImage(null);
|
||||
deleteModelImage(model_key)
|
||||
.unwrap()
|
||||
.then(() => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelImageDeleted'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelImageDeleteFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
});
|
||||
}, [dispatch, model_key, t, deleteModelImage]);
|
||||
|
||||
const { getInputProps, getRootProps } = useDropzone({
|
||||
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
|
||||
onDropAccepted,
|
||||
noDrag: true,
|
||||
multiple: false,
|
||||
});
|
||||
|
||||
if (image) {
|
||||
return (
|
||||
<Box position="relative">
|
||||
<Image
|
||||
src={image}
|
||||
objectFit="cover"
|
||||
objectPosition="50% 50%"
|
||||
height="100px"
|
||||
width="100px"
|
||||
minWidth="100px"
|
||||
borderRadius="base"
|
||||
/>
|
||||
<IconButton
|
||||
position="absolute"
|
||||
top="1"
|
||||
right="1"
|
||||
onClick={handleResetImage}
|
||||
aria-label={t('modelManager.deleteModelImage')}
|
||||
tooltip={t('modelManager.deleteModelImage')}
|
||||
icon={<PiArrowCounterClockwiseBold size={16} />}
|
||||
size="sm"
|
||||
variant="link"
|
||||
_hover={{ color: 'base.100' }}
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button leftIcon={<PiUploadSimpleBold />} {...getRootProps()} pointerEvents="auto">
|
||||
{t('modelManager.uploadImage')}
|
||||
</Button>
|
||||
<input {...getInputProps()} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default typedMemo(ModelImageUpload);
|
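The deleted ModelImageUpload component above drives its toasts off RTK Query's .unwrap(), which turns a dispatched mutation into a plain promise that resolves on success and rejects on error. The pattern in isolation (a sketch; the hook name comes from the file above, while showToast is a hypothetical stand-in for the dispatch/addToast/makeToast chain):

// Sketch of the mutation-to-toast pattern used above (showToast is hypothetical).
const [updateModelImage] = useUpdateModelImageMutation();
updateModelImage({ key: modelKey, image: file })
  .unwrap() // resolves on success, rejects on error
  .then(() => showToast('modelManager.modelImageUpdated', 'success'))
  .catch(() => showToast('modelManager.modelImageUpdateFailed', 'error'));
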
@ -4,7 +4,6 @@ import { useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';

import ModelImageUpload from './Fields/ModelImageUpload';
import { ModelMetadata } from './Metadata/ModelMetadata';
import { ModelAttrView } from './ModelAttrView';
import { ModelEdit } from './ModelEdit';
@ -26,22 +25,19 @@ export const Model = () => {

return (
<>
<Flex alignItems="center" justifyContent="space-between" gap="4" paddingRight="5">
<Flex flexDir="column" gap={1} p={2}>
<Heading as="h2" fontSize="lg">
{data.name}
</Heading>
<Flex flexDir="column" gap={1} p={2}>
<Heading as="h2" fontSize="lg">
{data.name}
</Heading>

{data.source && (
<Text variant="subtext">
{t('modelManager.source')}: {data?.source}
</Text>
)}
<Box mt="4">
<ModelAttrView label="Description" value={data.description} />
</Box>
</Flex>
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
{data.source && (
<Text variant="subtext">
{t('modelManager.source')}: {data?.source}
</Text>
)}
<Box mt="4">
<ModelAttrView label="Description" value={data.description} />
</Box>
</Flex>

<Tabs mt="4" h="100%">

@ -129,7 +129,7 @@ export const ModelEdit = () => {
</Flex>

<Flex flexDir="column" gap={3} mt="4">
<Flex gap="4" alignItems="center">
<Flex>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea fontSize="md" resize="none" {...register('description')} />
@ -145,32 +145,20 @@
</FormControl>
</Flex>
{data.type === 'main' && data.format === 'checkpoint' && (
<>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input
{...register('config_path', {
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
})}
/>
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={control} />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<Checkbox {...register('upcast_attention')} />
</FormControl>
</Flex>
</>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<Checkbox {...register('upcast_attention')} />
</FormControl>
</Flex>
)}
</Flex>
</form>

@ -90,26 +90,21 @@ export const ModelView = () => {
<ModelAttrView label={t('common.format')} value={modelData.format} />
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Flex>

{modelData.type === 'main' && modelData.format === 'diffusers' && modelData.repo_variant && (
{modelData.type === 'main' && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
{modelData.format === 'diffusers' && modelData.repo_variant && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<>
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</>
)}
</Flex>
)}

{modelData.type === 'main' && modelData.format === 'checkpoint' && (
<>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</Flex>
</>
)}

{modelData.type === 'ip_adapter' && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
@ -117,11 +112,9 @@ export const ModelView = () => {
)}
</Flex>
</Box>
{modelData.type === 'main' && (
<Box layerStyle="second" borderRadius="base" p={3}>
<DefaultSettings />
</Box>
)}
<Box layerStyle="second" borderRadius="base" p={3}>
<DefaultSettings />
</Box>
</Flex>
);
};

@ -1,55 +0,0 @@
|
||||
import { Button, Flex, Image, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { workflowModeChanged } from 'features/nodes/store/workflowSlice';
|
||||
import InvokeLogoSVG from 'public/assets/images/invoke-symbol-wht-lrg.svg';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const EmptyState = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(workflowModeChanged('edit'));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
userSelect: 'none',
|
||||
}}
|
||||
>
|
||||
<Flex
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
borderRadius: 'base',
|
||||
flexDir: 'column',
|
||||
gap: 5,
|
||||
maxW: '230px',
|
||||
margin: '0 auto',
|
||||
}}
|
||||
>
|
||||
<Image
|
||||
src={InvokeLogoSVG}
|
||||
alt="invoke-ai-logo"
|
||||
opacity={0.2}
|
||||
mixBlendMode="overlay"
|
||||
w={16}
|
||||
h={16}
|
||||
minW={16}
|
||||
minH={16}
|
||||
userSelect="none"
|
||||
/>
|
||||
<Text textAlign="center" fontSize="md">
|
||||
{t('nodes.noFieldsViewMode')}
|
||||
</Text>
|
||||
<Button colorScheme="invokeBlue" onClick={onClick}>
|
||||
{t('nodes.edit')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -7,7 +7,6 @@ import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
import { t } from 'i18next';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';

import { EmptyState } from './EmptyState';
import WorkflowField from './WorkflowField';

const selector = createMemoizedSelector(selectWorkflowSlice, (workflow) => {
@ -31,7 +30,7 @@ export const WorkflowViewMode = () => {
<WorkflowField key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
))
) : (
<EmptyState />
<IAINoContentFallback label={t('nodes.noFieldsLinearview')} icon={null} />
)}
</Flex>
</ScrollableContent>

@ -12,17 +12,17 @@ const WorkflowPanel = () => {
<Flex layerStyle="first" flexDir="column" w="full" h="full" borderRadius="base" p={2} gap={2}>
<Tabs variant="line" display="flex" w="full" h="full" flexDir="column">
<TabList>
<Tab>{t('common.linear')}</Tab>
<Tab>{t('common.details')}</Tab>
<Tab>{t('common.linear')}</Tab>
<Tab>JSON</Tab>
</TabList>

<TabPanels>
<TabPanel>
<WorkflowLinearTab />
<WorkflowGeneralTab />
</TabPanel>
<TabPanel>
<WorkflowGeneralTab />
<WorkflowLinearTab />
</TabPanel>
<TabPanel>
<WorkflowJSONTab />

@ -55,22 +55,8 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;

// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zSubModelType = z.enum([
'unet',
'text_encoder',
'text_encoder_2',
'tokenizer',
'tokenizer_2',
'vae',
'vae_decoder',
'vae_encoder',
'scheduler',
'safety_checker',
]);

const zModelIdentifier = z.object({
key: z.string().min(1),
submodel_type: zSubModelType.nullish(),
});
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
zModelIdentifier.safeParse(field).success;

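isModelIdentifier above is a standard zod type guard: safeParse returns a result object instead of throwing, so its .success flag doubles as the predicate. Illustrative inputs against the schema in this hunk (values made up):

// How the guard above behaves against zModelIdentifier (illustrative values).
isModelIdentifier({ key: 'abc123' }); // true: submodel_type is nullish, so it may be omitted or null
isModelIdentifier({ key: 'abc123', submodel_type: 'vae' }); // true: 'vae' is in zSubModelType
isModelIdentifier({ key: '' }); // false: key requires min(1)
isModelIdentifier({ key: 'abc123', submodel_type: 'lora' }); // false: not a zSubModelType member
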
@ -1,22 +1,18 @@
import type { RootState } from 'app/store/store';
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { omit } from 'lodash-es';
import type {
CollectInvocation,
ControlField,
ControlNetInvocation,
CoreMetadataInvocation,
NonNullableGraph,
} from 'services/api/types';
import { isControlNetModelConfig } from 'services/api/types';

import { CONTROL_NET_COLLECT } from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';

export const addControlNetToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
);
@ -43,7 +39,7 @@ export const addControlNetToLinearGraph = async (
},
});

validControlNets.forEach(async (controlNet) => {
validControlNets.forEach((controlNet) => {
if (!controlNet.model) {
return;
}
@ -89,17 +85,7 @@ export const addControlNetToLinearGraph = async (

graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;

const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isControlNetModelConfig);

controlNetMetadata.push({
control_model: getModelMetadataField(modelConfig),
control_weight: weight,
control_mode: controlMode,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
image: controlNetNode.image,
});
controlNetMetadata.push(omit(controlNetNode, ['id', 'type', 'is_intermediate']) as ControlField);

graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
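A note on the metadata hunk above: instead of awaiting fetchModelConfigWithTypeGuard per adapter, the new code derives metadata directly from the node it just built, using lodash's omit to strip graph-only bookkeeping fields. A minimal sketch of that pattern (the node shape is reduced for illustration and is not the file's exact type):

import { omit } from 'lodash-es';

type ControlNetNode = {
  id: string;
  type: 'controlnet';
  is_intermediate: boolean;
  control_model: { key: string };
  control_weight: number;
};

const node: ControlNetNode = {
  id: 'controlnet_1',
  type: 'controlnet',
  is_intermediate: true,
  control_model: { key: 'abc123' },
  control_weight: 0.75,
};

// Everything except graph bookkeeping doubles as the metadata record.
const metadata = omit(node, ['id', 'type', 'is_intermediate']);
console.log(metadata); // { control_model: { key: 'abc123' }, control_weight: 0.75 }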
@ -1,22 +1,18 @@
import type { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { omit } from 'lodash-es';
import type {
CollectInvocation,
CoreMetadataInvocation,
IPAdapterInvocation,
IPAdapterMetadataField,
NonNullableGraph,
} from 'services/api/types';
import { isIPAdapterModelConfig } from 'services/api/types';

import { IP_ADAPTER_COLLECT } from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';

export const addIPAdapterToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
);
@ -39,7 +35,7 @@ export const addIPAdapterToLinearGraph = async (

const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];

validIPAdapters.forEach(async (ipAdapter) => {
validIPAdapters.forEach((ipAdapter) => {
if (!ipAdapter.model) {
return;
}
@ -62,17 +58,9 @@ export const addIPAdapterToLinearGraph = async (
return;
}

graph.nodes[ipAdapterNode.id] = ipAdapterNode;
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;

const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isIPAdapterModelConfig);

ipAdapterMetdata.push({
weight: weight,
ip_adapter_model: getModelMetadataField(modelConfig),
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image: ipAdapterNode.image,
});
ipAdapterMetdata.push(omit(ipAdapterNode, ['id', 'type', 'is_intermediate']) as IPAdapterMetadataField);

graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
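Worth noting about the forEach hunks in these files: `forEach(async (x) => ...)` fires its callbacks without awaiting them, so the enclosing function could resolve before any awaited work inside the loop finished. With the model-config fetch removed, the callbacks can be plain synchronous functions and the hazard disappears. A small standalone demonstration of the hazard:

const items = [1, 2, 3];
const out: number[] = [];

const collect = async () => {
  items.forEach(async (n) => {
    await Promise.resolve(); // stand-in for a fetch
    out.push(n);
  });
  // forEach does not await its callbacks, so nothing has been pushed yet.
  console.log(out.length); // 0
};

collect();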
@ -1,22 +1,16 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { filter, size } from 'lodash-es';
import {
type CoreMetadataInvocation,
isLoRAModelConfig,
type LoRALoaderInvocation,
type NonNullableGraph,
} from 'services/api/types';
import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types';

import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';

export const addLoRAsToGraph = async (
export const addLoRAsToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId: string = MAIN_MODEL_LOADER
): Promise<void> => {
): void => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
@ -45,12 +39,12 @@ export const addLoRAsToGraph = async (
let currentLoraIndex = 0;
const loraMetadata: CoreMetadataInvocation['loras'] = [];

enabledLoRAs.forEach(async (lora) => {
enabledLoRAs.forEach((lora) => {
const { weight } = lora;
const { key } = lora.model;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;

const loraLoaderNode: LoRALoaderInvocation = {
const loraLoaderNode: LoraLoaderInvocation = {
type: 'lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
@ -58,10 +52,8 @@ export const addLoRAsToGraph = async (
weight,
};

const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);

loraMetadata.push({
model: getModelMetadataField(modelConfig),
model: { key },
weight,
});
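The doc comment in this file describes a chain: each LoRA loader takes the UNet and CLIP references from the previous node and hands them on to the next. A reduced sketch of how such a chain of graph edges can be wired (types trimmed to the minimum; this is not the file's actual helper):

type Edge = {
  source: { node_id: string; field: string };
  destination: { node_id: string; field: string };
};

// Wire each loader's unet/clip outputs into the next loader's inputs.
const buildLoraChain = (loaderIds: string[], modelLoaderId: string): Edge[] => {
  const edges: Edge[] = [];
  let previous = modelLoaderId;
  for (const id of loaderIds) {
    for (const field of ['unet', 'clip']) {
      edges.push({
        source: { node_id: previous, field },
        destination: { node_id: id, field },
      });
    }
    previous = id;
  }
  return edges;
};

console.log(buildLoraChain(['lora_loader_a', 'lora_loader_b'], 'main_model_loader'));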
@ -1,12 +1,6 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { filter, size } from 'lodash-es';
import {
type CoreMetadataInvocation,
isLoRAModelConfig,
type NonNullableGraph,
type SDXLLoRALoaderInvocation,
} from 'services/api/types';
import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoraLoaderInvocation } from 'services/api/types';

import {
LORA_LOADER,
@ -16,14 +10,14 @@ import {
SDXL_REFINER_INPAINT_CREATE_MASK,
SEAMLESS,
} from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';

export const addSDXLLoRAsToGraph = async (
export const addSDXLLoRAsToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId: string = SDXL_MODEL_LOADER
): Promise<void> => {
): void => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
@ -61,12 +55,12 @@ export const addSDXLLoRAsToGraph = async (
let lastLoraNodeId = '';
let currentLoraIndex = 0;

enabledLoRAs.forEach(async (lora) => {
enabledLoRAs.forEach((lora) => {
const { weight } = lora;
const { key } = lora.model;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;

const loraLoaderNode: SDXLLoRALoaderInvocation = {
const loraLoaderNode: SDXLLoraLoaderInvocation = {
type: 'sdxl_lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
@ -74,9 +68,7 @@ export const addSDXLLoRAsToGraph = async (
weight,
};

const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);

loraMetadata.push({ model: getModelMetadataField(modelConfig), weight });
loraMetadata.push({ model: { key }, weight });

// add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode;
@ -1,11 +1,9 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type CreateDenoiseMaskInvocation,
type ImageDTO,
isRefinerMainModelModelConfig,
type NonNullableGraph,
type SeamlessModeInvocation,
import type {
CreateDenoiseMaskInvocation,
ImageDTO,
NonNullableGraph,
SeamlessModeInvocation,
} from 'services/api/types';

import {
@ -27,16 +25,16 @@ import {
SDXL_REFINER_SEAMLESS,
} from './constants';
import { getSDXLStylePrompts } from './graphBuilderUtils';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';

export const addSDXLRefinerToGraph = async (
export const addSDXLRefinerToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId?: string,
canvasInitImage?: ImageDTO,
canvasMaskImage?: ImageDTO
): Promise<void> => {
): void => {
const {
refinerModel,
refinerPositiveAestheticScore,
@ -57,10 +55,9 @@ export const addSDXLRefinerToGraph = async (
const fp32 = vaePrecision === 'fp32';

const isUsingScaledDimensions = ['auto', 'manual'].includes(boundingBoxScaleMethod);
const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig);

upsertMetadata(graph, {
refiner_model: getModelMetadataField(modelConfig),
refiner_model: refinerModel,
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
refiner_cfg_scale: refinerCFGScale,
@ -1,22 +1,18 @@
import type { RootState } from 'app/store/store';
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type CollectInvocation,
type CoreMetadataInvocation,
isT2IAdapterModelConfig,
type NonNullableGraph,
type T2IAdapterInvocation,
import { omit } from 'lodash-es';
import type {
CollectInvocation,
CoreMetadataInvocation,
NonNullableGraph,
T2IAdapterField,
T2IAdapterInvocation,
} from 'services/api/types';

import { T2I_ADAPTER_COLLECT } from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';

export const addT2IAdaptersToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
);
@ -37,9 +33,9 @@ export const addT2IAdaptersToLinearGraph = async (
},
});

const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
const t2iAdapterMetdata: CoreMetadataInvocation['t2iAdapters'] = [];

validT2IAdapters.forEach(async (t2iAdapter) => {
validT2IAdapters.forEach((t2iAdapter) => {
if (!t2iAdapter.model) {
return;
}
@ -81,18 +77,9 @@ export const addT2IAdaptersToLinearGraph = async (
return;
}

graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode as T2IAdapterInvocation;

const modelConfig = await fetchModelConfigWithTypeGuard(t2iAdapter.model.key, isT2IAdapterModelConfig);

t2iAdapterMetadata.push({
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
t2i_adapter_model: getModelMetadataField(modelConfig),
weight: weight,
image: t2iAdapterNode.image,
});
t2iAdapterMetdata.push(omit(t2iAdapterNode, ['id', 'type', 'is_intermediate']) as T2IAdapterField);

graph.edges.push({
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
@ -103,6 +90,6 @@ export const addT2IAdaptersToLinearGraph = async (
});
});

upsertMetadata(graph, { t2iAdapters: t2iAdapterMetadata });
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetdata });
}
};
@ -1,7 +1,5 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { ModelMetadataField, NonNullableGraph } from 'services/api/types';
import { isVAEModelConfig } from 'services/api/types';
import type { NonNullableGraph } from 'services/api/types';

import {
CANVAS_IMAGE_TO_IMAGE_GRAPH,
@ -25,13 +23,13 @@ import {
TEXT_TO_IMAGE_GRAPH,
VAE_LOADER,
} from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';

export const addVAEToGraph = async (
export const addVAEToGraph = (
state: RootState,
graph: NonNullableGraph,
modelLoaderNodeId: string = MAIN_MODEL_LOADER
): Promise<void> => {
): void => {
const { vae, seamlessXAxis, seamlessYAxis } = state.generation;
const { boundingBoxScaleMethod } = state.canvas;
const { refinerModel } = state.sdxl;
@ -151,8 +149,6 @@ export const addVAEToGraph = async (
}

if (vae) {
const modelConfig = await fetchModelConfigWithTypeGuard(vae.key, isVAEModelConfig);
const vaeMetadata: ModelMetadataField = getModelMetadataField(modelConfig);
upsertMetadata(graph, { vae: vaeMetadata });
upsertMetadata(graph, { vae });
}
};
@ -10,46 +10,46 @@ import { buildCanvasSDXLOutpaintGraph } from './buildCanvasSDXLOutpaintGraph';
import { buildCanvasSDXLTextToImageGraph } from './buildCanvasSDXLTextToImageGraph';
import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph';

export const buildCanvasGraph = async (
export const buildCanvasGraph = (
state: RootState,
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
canvasInitImage: ImageDTO | undefined,
canvasMaskImage: ImageDTO | undefined
): Promise<NonNullableGraph> => {
) => {
let graph: NonNullableGraph;

if (generationMode === 'txt2img') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = await buildCanvasSDXLTextToImageGraph(state);
graph = buildCanvasSDXLTextToImageGraph(state);
} else {
graph = await buildCanvasTextToImageGraph(state);
graph = buildCanvasTextToImageGraph(state);
}
} else if (generationMode === 'img2img') {
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = await buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
} else {
graph = await buildCanvasImageToImageGraph(state, canvasInitImage);
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
}
} else if (generationMode === 'inpaint') {
if (!canvasInitImage || !canvasMaskImage) {
throw new Error('Missing canvas init and mask images');
}
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = await buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
} else {
graph = await buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
}
} else {
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = await buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
} else {
graph = await buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
}
}
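A side note on the hunk above: the new signature ends in `) => {` with no annotation, so the return type is inferred as `NonNullableGraph` because every branch either assigns `graph` or throws. A reduced sketch of this dispatch-plus-inference shape (mode names taken from the hunk; the builders are stubbed for illustration):

type Graph = { nodes: Record<string, unknown> };
const buildTxt2Img = (): Graph => ({ nodes: {} });
const buildImg2Img = (init: string): Graph => ({ nodes: { init: { image: init } } });

// Return type is inferred as Graph: every branch assigns one or throws.
const buildGraph = (mode: 'txt2img' | 'img2img', init?: string) => {
  let graph: Graph;
  if (mode === 'txt2img') {
    graph = buildTxt2Img();
  } else {
    if (!init) {
      throw new Error('Missing init image');
    }
    graph = buildImg2Img(init);
  }
  return graph;
};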
@ -1,13 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import {
type ImageDTO,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';

import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -31,15 +25,12 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';

/**
* Builds the Canvas tab's Image to Image graph.
*/
export const buildCanvasImageToImageGraph = async (
state: RootState,
initialImage: ImageDTO
): Promise<NonNullableGraph> => {
export const buildCanvasImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -315,8 +306,6 @@ export const buildCanvasImageToImageGraph = async (
});
}

const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);

addCoreMetadataNode(
graph,
{
@ -327,7 +316,7 @@ export const buildCanvasImageToImageGraph = async (
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -346,17 +335,17 @@ export const buildCanvasImageToImageGraph = async (
}

// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS);
addLoRAsToGraph(state, graph, DENOISE_LATENTS);

// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);

// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);

// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);

// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
@ -40,11 +40,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
/**
* Builds the Canvas tab's Inpaint graph.
*/
export const buildCanvasInpaintGraph = async (
export const buildCanvasInpaintGraph = (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage: ImageDTO
): Promise<NonNullableGraph> => {
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -414,17 +414,17 @@ export const buildCanvasInpaintGraph = async (
}

// Add VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);

// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);

// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);

// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@ -44,11 +44,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
/**
* Builds the Canvas tab's Outpaint graph.
*/
export const buildCanvasOutpaintGraph = async (
export const buildCanvasOutpaintGraph = (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage?: ImageDTO
): Promise<NonNullableGraph> => {
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -545,18 +545,18 @@ export const buildCanvasOutpaintGraph = async (
}

// Add VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);

// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);

// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);

// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);

await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);

// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
@ -1,12 +1,6 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type ImageDTO,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';

import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -32,15 +26,12 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';

/**
* Builds the Canvas tab's Image to Image graph.
*/
export const buildCanvasSDXLImageToImageGraph = async (
state: RootState,
initialImage: ImageDTO
): Promise<NonNullableGraph> => {
export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -316,8 +307,6 @@ export const buildCanvasSDXLImageToImageGraph = async (
});
}

const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);

addCoreMetadataNode(
graph,
{
@ -328,7 +317,7 @@ export const buildCanvasSDXLImageToImageGraph = async (
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -349,24 +338,24 @@ export const buildCanvasSDXLImageToImageGraph = async (

// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}

// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);

// add LoRA support
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);

// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
@ -41,11 +41,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu
/**
* Builds the Canvas tab's Inpaint graph.
*/
export const buildCanvasSDXLInpaintGraph = async (
export const buildCanvasSDXLInpaintGraph = (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage: ImageDTO
): Promise<NonNullableGraph> => {
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -426,31 +426,24 @@ export const buildCanvasSDXLInpaintGraph = async (

// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefinerToGraph(
state,
graph,
SDXL_DENOISE_LATENTS,
modelLoaderNodeId,
canvasInitImage,
canvasMaskImage
);
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage, canvasMaskImage);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}

// Add VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);

// add LoRA support
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);

// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
@ -45,11 +45,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu
/**
* Builds the Canvas tab's Outpaint graph.
*/
export const buildCanvasSDXLOutpaintGraph = async (
export const buildCanvasSDXLOutpaintGraph = (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage?: ImageDTO
): Promise<NonNullableGraph> => {
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -555,25 +555,25 @@ export const buildCanvasSDXLOutpaintGraph = async (

// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}

// Add VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);

// add LoRA support
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);

// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
@ -1,7 +1,6 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import type { NonNullableGraph } from 'services/api/types';

import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -25,12 +24,12 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';

/**
* Builds the Canvas tab's Text to Image graph.
*/
export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -273,8 +272,6 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise
});
}

const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);

addCoreMetadataNode(
graph,
{
@ -287,7 +284,7 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise
negative_prompt: negativePrompt,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -304,24 +301,24 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise

// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}

// add LoRA support
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);

// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);

// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
@ -1,8 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import type { NonNullableGraph } from 'services/api/types';

import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -24,12 +23,12 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';

/**
* Builds the Canvas tab's Text to Image graph.
*/
export const buildCanvasTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -263,8 +262,6 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise<Non
});
}

const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);

addCoreMetadataNode(
graph,
{
@ -275,7 +272,7 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise<Non
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -292,17 +289,17 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise<Non
}

// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);

// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);

// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);

// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);

// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
@ -1,13 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import {
type ImageResizeInvocation,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';

import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -30,12 +24,12 @@ import {
RESIZE,
SEAMLESS,
} from './constants';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';

/**
* Builds the Image to Image tab graph.
*/
export const buildLinearImageToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -313,8 +307,6 @@ export const buildLinearImageToImageGraph = async (state: RootState): Promise<No
});
}

const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);

addCoreMetadataNode(
graph,
{
@ -325,7 +317,7 @@ export const buildLinearImageToImageGraph = async (state: RootState): Promise<No
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -344,17 +336,17 @@ export const buildLinearImageToImageGraph = async (state: RootState): Promise<No
}

// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);

// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);

// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);

// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);

// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
@ -1,12 +1,6 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type ImageResizeInvocation,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';

import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -31,12 +25,12 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';

/**
* Builds the Image to Image tab graph.
*/
export const buildLinearSDXLImageToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -324,8 +318,6 @@ export const buildLinearSDXLImageToImageGraph = async (state: RootState): Promis
});
}

const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);

addCoreMetadataNode(
graph,
{
@ -336,7 +328,7 @@ export const buildLinearSDXLImageToImageGraph = async (state: RootState): Promis
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -357,25 +349,25 @@ export const buildLinearSDXLImageToImageGraph = async (state: RootState): Promis

// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}

// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);

// Add LoRA Support
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);

// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
@ -1,7 +1,6 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import type { NonNullableGraph } from 'services/api/types';

import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -24,9 +23,9 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';

export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -222,8 +221,6 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
],
};

const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);

addCoreMetadataNode(
graph,
{
@ -234,7 +231,7 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -253,25 +250,25 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise

// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}

// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);

// add LoRA support
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);

// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

// add IP Adapter
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
@ -1,8 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import type { NonNullableGraph } from 'services/api/types';

import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addHrfToGraph } from './addHrfToGraph';
@ -24,9 +23,9 @@ import {
SEAMLESS,
TEXT_TO_IMAGE_GRAPH,
} from './constants';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
import { addCoreMetadataNode } from './metadata';

export const buildLinearTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
@ -213,8 +212,6 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non
],
};

const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);

addCoreMetadataNode(
graph,
{
@ -225,7 +222,7 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
model,
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -242,18 +239,18 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non
}

// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
addVAEToGraph(state, graph, modelLoaderNodeId);

// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);

// add controlnet, mutating `graph`
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);

// add IP Adapter
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);

await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);

// High resolution fix.
if (state.hrf.hrfEnabled) {
@ -1,5 +1,5 @@
import type { JSONObject } from 'common/types';
import type { AnyModelConfig, CoreMetadataInvocation, ModelMetadataField, NonNullableGraph } from 'services/api/types';
import type { CoreMetadataInvocation, NonNullableGraph } from 'services/api/types';

import { METADATA } from './constants';

@ -71,11 +71,3 @@ export const setMetadataReceivingNode = (graph: NonNullableGraph, nodeId: string
},
});
};

export const getModelMetadataField = ({ key, hash, name, base, type }: AnyModelConfig): ModelMetadataField => ({
key,
hash,
name,
base,
type,
});
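With `getModelMetadataField` gone, the metadata node simply stores whatever partial objects callers pass to `upsertMetadata`. A minimal sketch of upsert-style merging into a single metadata node (the graph shape is simplified and this is not the file's exact implementation):

type Graph = { nodes: Record<string, Record<string, unknown>> };

const METADATA = 'core_metadata';

// Create the metadata node on first use, then shallow-merge new fields in.
const upsertMetadata = (graph: Graph, metadata: Record<string, unknown>): void => {
  const existing = graph.nodes[METADATA] ?? { id: METADATA, type: 'core_metadata' };
  graph.nodes[METADATA] = { ...existing, ...metadata };
};

const graph: Graph = { nodes: {} };
upsertMetadata(graph, { vae: { key: 'abc' } });
upsertMetadata(graph, { seed: 42 });
console.log(graph.nodes[METADATA]); // both fields present on the one node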
@ -38,7 +38,6 @@ export const ParamNegativePrompt = memo(() => {
onKeyDown={onKeyDown}
fontSize="sm"
variant="darkFilled"
paddingRight={30}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
@ -54,7 +54,6 @@ export const ParamPositivePrompt = memo(() => {
minH={28}
onKeyDown={onKeyDown}
variant="darkFilled"
paddingRight={30}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
@ -41,7 +41,6 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
onKeyDown={onKeyDown}
fontSize="sm"
variant="darkFilled"
paddingRight={30}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
@ -38,7 +38,6 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
onKeyDown={onKeyDown}
fontSize="sm"
variant="darkFilled"
paddingRight={30}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
@ -109,7 +109,7 @@ export const imagesApi = api.injectEndpoints({
}),
clearIntermediates: build.mutation<number, void>({
query: () => ({ url: buildImagesUrl('intermediates'), method: 'DELETE' }),
invalidatesTags: ['IntermediatesCount', 'InvocationCacheStatus'],
invalidatesTags: ['IntermediatesCount'],
}),
getImageDTO: build.query<ImageDTO, string>({
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}`) }),
@ -125,7 +125,7 @@ export const imagesApi = api.injectEndpoints({
paths['/api/v1/images/i/{image_name}/workflow']['get']['responses']['200']['content']['application/json'],
string
>({
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/workflow`) }),
query: (image_name) => ({ url: `images/i/${image_name}/workflow` }),
providesTags: (result, error, image_name) => [{ type: 'ImageWorkflow', id: image_name }],
keepUnusedDataFor: 86400, // 24 hours
}),
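Context for the `clearIntermediates` hunk: in RTK Query, a mutation's `invalidatesTags` lists the tag types whose cached queries should refetch after the mutation succeeds, so narrowing the list reduces unnecessary refetches. A minimal, self-contained sketch of the mechanism (endpoint names and the base URL are illustrative, not this API's):

import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query';

const api = createApi({
  baseQuery: fetchBaseQuery({ baseUrl: '/api/v1' }),
  tagTypes: ['IntermediatesCount'],
  endpoints: (build) => ({
    getIntermediatesCount: build.query<number, void>({
      query: () => 'images/intermediates',
      providesTags: ['IntermediatesCount'],
    }),
    clearIntermediates: build.mutation<number, void>({
      query: () => ({ url: 'images/intermediates', method: 'DELETE' }),
      // Any cached query providing this tag refetches after the mutation.
      invalidatesTags: ['IntermediatesCount'],
    }),
  }),
});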
@ -23,14 +23,7 @@ export type UpdateModelArg = {
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
};

export type UpdateModelImageArg = {
key: string;
image: Blob;
};

type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateModelImageResponse =
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];

type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];

@ -40,7 +33,6 @@ type DeleteModelArg = {
key: string;
};
type DeleteModelResponse = void;
type DeleteModelImageResponse = void;

type ConvertMainModelResponse =
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
@ -152,18 +144,6 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
query: ({ key, image }) => {
const formData = new FormData();
formData.append('image', image);
return {
url: buildModelsUrl(`i/${key}/image`),
method: 'PATCH',
body: formData,
};
},
invalidatesTags: ['Model'],
}),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source }) => {
return {
@ -183,18 +163,6 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
deleteModelImage: build.mutation<DeleteModelImageResponse, string>({
query: (key) => {
return {
url: buildModelsUrl(`i/${key}/image`),
method: 'DELETE',
};
},
invalidatesTags: ['Model'],
}),
getModelImage: build.query<string, string>({
query: (key) => buildModelsUrl(`i/${key}/image`),
}),
convertModel: build.mutation<ConvertMainModelResponse, string>({
query: (key) => {
return {
@ -362,9 +330,7 @@ export const {
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
useDeleteModelsMutation,
useDeleteModelImageMutation,
useUpdateModelMutation,
useUpdateModelImageMutation,
useInstallModelMutation,
useConvertModelMutation,
useSyncModelsMutation,
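For reference, the removed `updateModelImage` endpoint relied on `FormData` so the image `Blob` is sent as `multipart/form-data`. The same pattern shown standalone with plain `fetch` (the URL shape and field name are copied from the hunk; the helper itself is hypothetical):

// Send a binary image to a PATCH endpoint as multipart/form-data.
const updateModelImage = async (key: string, image: Blob): Promise<void> => {
  const formData = new FormData();
  formData.append('image', image);
  // fetch sets the multipart boundary header automatically for FormData bodies.
  await fetch(`/api/v2/models/i/${key}/image`, { method: 'PATCH', body: formData });
};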
File diff suppressed because one or more lines are too long
@ -38,6 +38,7 @@ export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDT
export type ModelType = S['ModelType'];
export type SubModelType = S['SubModelType'];
export type BaseModelType = S['BaseModelType'];
export type ControlField = S['ControlField'];

// Model Configs

@ -119,18 +120,17 @@ export type CreateGradientMaskInvocation = S['CreateGradientMaskInvocation'];
export type CanvasPasteBackInvocation = S['CanvasPasteBackInvocation'];
export type NoiseInvocation = S['NoiseInvocation'];
export type DenoiseLatentsInvocation = S['DenoiseLatentsInvocation'];
export type SDXLLoRALoaderInvocation = S['SDXLLoRALoaderInvocation'];
export type SDXLLoraLoaderInvocation = S['SDXLLoraLoaderInvocation'];
export type ImageToLatentsInvocation = S['ImageToLatentsInvocation'];
export type LatentsToImageInvocation = S['LatentsToImageInvocation'];
export type LoRALoaderInvocation = S['LoRALoaderInvocation'];
export type LoraLoaderInvocation = S['LoraLoaderInvocation'];
export type ESRGANInvocation = S['ESRGANInvocation'];
export type ImageNSFWBlurInvocation = S['ImageNSFWBlurInvocation'];
export type ImageWatermarkInvocation = S['ImageWatermarkInvocation'];
export type SeamlessModeInvocation = S['SeamlessModeInvocation'];
export type CoreMetadataInvocation = S['CoreMetadataInvocation'];

// Metadata fields
export type ModelMetadataField = S['ModelMetadataField'];
export type IPAdapterMetadataField = S['IPAdapterMetadataField'];
export type T2IAdapterField = S['T2IAdapterField'];

// ControlNet Nodes
export type ControlNetInvocation = S['ControlNetInvocation'];
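A note on these renames (`SDXLLoRALoaderInvocation` to `SDXLLoraLoaderInvocation`, and so on): each alias is an indexed lookup into the generated OpenAPI schema namespace, so renaming the alias tracks a renamed schema key with no runtime effect. A tiny sketch of the aliasing pattern (the `S` shape here is invented for illustration):

// Stand-in for the generated OpenAPI schema components.
type S = {
  LoraLoaderInvocation: { type: 'lora_loader'; lora: { key: string }; weight: number };
};

// The alias is a pure compile-time lookup by string key into S.
export type LoraLoaderInvocation = S['LoraLoaderInvocation'];

const node: LoraLoaderInvocation = { type: 'lora_loader', lora: { key: 'k' }, weight: 1 };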
@ -33,15 +33,19 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.latent import SchedulerOutput
from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOutput, MetadataOutput
from invokeai.app.invocations.model import (
CLIPField,
ClipField,
CLIPOutput,
LoRALoaderOutput,
ModelField,
LoraInfo,
LoraLoaderOutput,
LoRAModelField,
MainModelField,
ModelInfo,
ModelLoaderOutput,
SDXLLoRALoaderOutput,
SDXLLoraLoaderOutput,
UNetField,
UNetOutput,
VAEField,
VaeField,
VAEModelField,
VAEOutput,
)
from invokeai.app.invocations.primitives import (
@ -69,8 +73,8 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_management.model_manager import LoadedModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
@ -114,16 +118,20 @@ __all__ = [
"MetadataItemOutput",
"MetadataOutput",
# invokeai.app.invocations.model
"ModelField",
"ModelInfo",
"LoraInfo",
"UNetField",
"CLIPField",
"VAEField",
"ClipField",
"VaeField",
"MainModelField",
"LoRAModelField",
"VAEModelField",
"UNetOutput",
"VAEOutput",
"CLIPOutput",
"ModelLoaderOutput",
"LoRALoaderOutput",
"SDXLLoRALoaderOutput",
"LoraLoaderOutput",
"SDXLLoraLoaderOutput",
# invokeai.app.invocations.primitives
"BooleanCollectionOutput",
"BooleanOutput",
@ -158,7 +166,7 @@ __all__ = [
# invokeai.app.services.config.config_default
"InvokeAIAppConfig",
# invokeai.backend.model_management.model_manager
"LoadedModel",
"LoadedModelInfo",
# invokeai.backend.model_management.models.base
"BaseModelType",
"ModelType",
Some files were not shown because too many files have changed in this diff