Compare commits

..

1 Commits

Author SHA1 Message Date
3dfa8ce11f feat(ui): upgrade to reactflow 12 2024-08-10 20:40:34 +10:00
222 changed files with 20273 additions and 28159 deletions

View File

@ -62,7 +62,7 @@ jobs:
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pip install ruff==0.6.0
run: pip install ruff
shell: bash
- name: ruff check

View File

@ -60,7 +60,7 @@ jobs:
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- platform: macos-default
os: macOS-14
os: macOS-12
github-env: $GITHUB_ENV
- platform: windows-cpu
os: windows-2022

View File

@ -1,22 +1,20 @@
# Invoke in Docker
First things first:
- Ensure that Docker can use your [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] GPU.
- This document assumes a Linux system, but should work similarly under Windows with WSL2.
- Ensure that Docker can use the GPU on your system
- This documentation assumes Linux, but should work similarly under Windows with WSL2
- We don't recommend running Invoke in Docker on macOS at this time. It works, but very slowly.
## Quickstart
## Quickstart :lightning:
No `docker compose`, no persistence, single command, using the official images:
No `docker compose`, no persistence, just a simple one-liner using the official images:
**CUDA (NVIDIA GPU):**
**CUDA:**
```bash
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
```
**ROCm (AMD GPU):**
**ROCm:**
```bash
docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invoke-ai/invokeai:main-rocm
@ -24,20 +22,12 @@ docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invok
Open `http://localhost:9090` in your browser once the container finishes booting, install some models, and generate away!
### Data persistence
To persist your generated images and downloaded models outside of the container, add a `--volume/-v` flag to the above command, e.g.:
```bash
docker run --volume /some/local/path:/invokeai {...etc...}
```
`/some/local/path/invokeai` will contain all your data.
It can *usually* be reused between different installs of Invoke. Tread with caution and read the release notes!
> [!TIP]
> To persist your data (including downloaded models) outside of the container, add a `--volume/-v` flag to the above command, e.g.: `docker run --volume /some/local/path:/invokeai <...the rest of the command>`
## Customize the container
The included `run.sh` script is a convenience wrapper around `docker compose`. It can be helpful for passing additional build arguments to `docker compose`. Alternatively, the familiar `docker compose` commands work just as well.
We ship the `run.sh` script, which is a convenient wrapper around `docker compose` for cases where custom image build args are needed. Alternatively, the familiar `docker compose` commands work just as well.
```bash
cd docker
@ -48,14 +38,11 @@ cp .env.sample .env
It will take a few minutes to build the image the first time. Once the application starts up, open `http://localhost:9090` in your browser to invoke!
>[!TIP]
>When using the `run.sh` script, the container will continue running after Ctrl+C. To shut it down, use the `docker compose down` command.
## Docker setup in detail
#### Linux
1. Ensure buildkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
1. Ensure builkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://docs.docker.com/compose/install/linux/#install-using-the-repository).
- The deprecated `docker-compose` (hyphenated) CLI probably won't work. Update to a recent version.
3. Ensure docker daemon is able to access the GPU.
@ -111,7 +98,25 @@ GPU_DRIVER=cuda
Any environment variables supported by InvokeAI can be set here. See the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
---
## Even More Customizing!
[nvidia docker docs]: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
[amd docker docs]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
### Reconfigure the runtime directory
Can be used to download additional models from the supported model list
In conjunction with `INVOKEAI_ROOT` can be also used to initialize a runtime directory
```yaml
command:
- invokeai-configure
- --yes
```
Or install models:
```yaml
command:
- invokeai-model-install
```

View File

@ -299,7 +299,7 @@ Migration logic is in [migrations.ts].
[pydantic]: https://github.com/pydantic/pydantic 'pydantic'
[zod]: https://github.com/colinhacks/zod 'zod'
[openapi-types]: https://github.com/kogosoftwarellc/open-api/tree/main/packages/openapi-types 'openapi-types'
[reactflow]: https://github.com/xyflow/xyflow 'reactflow'
[reactflow]: https://github.com/xyflow/xyflow '@xyflow/react'
[reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions
[reactflow-events]: https://reactflow.dev/api-reference/react-flow#event-handlers
[buildWorkflow.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts

View File

@ -17,7 +17,7 @@
set -eu
# Ensure we're in the correct folder in case user's CWD is somewhere else
scriptdir=$(dirname $(readlink -f "$0"))
scriptdir=$(dirname "$0")
cd "$scriptdir"
. .venv/bin/activate

View File

@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from logging import Logger
import torch
@ -32,8 +31,6 @@ from invokeai.app.services.session_processor.session_processor_default import (
)
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@ -66,12 +63,7 @@ class ApiDependencies:
invoker: Invoker
@staticmethod
def initialize(
config: InvokeAIAppConfig,
event_handler_id: int,
loop: asyncio.AbstractEventLoop,
logger: Logger = logger,
) -> None:
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
logger.info(f"InvokeAI version {__version__}")
logger.info(f"Root directory = {str(config.root_path)}")
@ -82,7 +74,6 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images")
model_images_folder = config.models_path
style_presets_folder = config.style_presets_path
db = init_db(config=config, logger=logger, image_files=image_files)
@ -93,7 +84,7 @@ class ApiDependencies:
board_images = BoardImagesService()
board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService()
events = FastAPIEventService(event_handler_id, loop=loop)
events = FastAPIEventService(event_handler_id)
bulk_download = BulkDownloadService()
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
@ -118,8 +109,6 @@ class ApiDependencies:
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db)
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
services = InvocationServices(
board_image_records=board_image_records,
@ -145,8 +134,6 @@ class ApiDependencies:
workflow_records=workflow_records,
tensors=tensors,
conditioning=conditioning,
style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files,
)
ApiDependencies.invoker = Invoker(services)

View File

@ -218,8 +218,9 @@ async def get_image_workflow(
raise HTTPException(status_code=404)
@images_router.get(
@images_router.api_route(
"/i/{image_name}/full",
methods=["GET", "HEAD"],
operation_id="get_image_full",
response_class=Response,
responses={
@ -230,18 +231,6 @@ async def get_image_workflow(
404: {"description": "Image not found"},
},
)
@images_router.head(
"/i/{image_name}/full",
operation_id="get_image_full_head",
response_class=Response,
responses={
200: {
"description": "Return the full-resolution image",
"content": {"image/png": {}},
},
404: {"description": "Image not found"},
},
)
async def get_image_full(
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> Response:
@ -253,7 +242,6 @@ async def get_image_full(
content = f.read()
response = Response(content, media_type="image/png")
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
response.headers["Content-Disposition"] = f'inline; filename="{image_name}"'
return response
except Exception:
raise HTTPException(status_code=404)

View File

@ -1,274 +0,0 @@
import csv
import io
import json
import traceback
from typing import Optional
import pydantic
from fastapi import APIRouter, File, Form, HTTPException, Path, Response, UploadFile
from fastapi.responses import FileResponse
from PIL import Image
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
from invokeai.app.services.style_preset_records.style_preset_records_common import (
InvalidPresetImportDataError,
PresetData,
PresetType,
StylePresetChanges,
StylePresetNotFoundError,
StylePresetRecordWithImage,
StylePresetWithoutId,
UnsupportedFileTypeError,
parse_presets_from_file,
)
class StylePresetFormData(BaseModel):
name: str = Field(description="Preset name")
positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt")
type: PresetType = Field(description="Preset type")
style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"])
@style_presets_router.get(
"/i/{style_preset_id}",
operation_id="get_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def get_style_preset(
style_preset_id: str = Path(description="The style preset to get"),
) -> StylePresetRecordWithImage:
"""Gets a style preset"""
try:
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.get(style_preset_id)
return StylePresetRecordWithImage(image=image, **style_preset.model_dump())
except StylePresetNotFoundError:
raise HTTPException(status_code=404, detail="Style preset not found")
@style_presets_router.patch(
"/i/{style_preset_id}",
operation_id="update_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def update_style_preset(
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
style_preset_id: str = Path(description="The id of the style preset to update"),
data: str = Form(description="The data of the style preset to update"),
) -> StylePresetRecordWithImage:
"""Updates a style preset"""
if image is not None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
try:
ApiDependencies.invoker.services.style_preset_image_files.save(style_preset_id, pil_image)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
else:
try:
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
except StylePresetImageFileNotFoundException:
pass
try:
parsed_data = json.loads(data)
validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name
type = validated_data.type
positive_prompt = validated_data.positive_prompt
negative_prompt = validated_data.negative_prompt
except pydantic.ValidationError:
raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
changes = StylePresetChanges(name=name, preset_data=preset_data, type=type)
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.update(
style_preset_id=style_preset_id, changes=changes
)
return StylePresetRecordWithImage(image=style_preset_image, **style_preset.model_dump())
@style_presets_router.delete(
"/i/{style_preset_id}",
operation_id="delete_style_preset",
)
async def delete_style_preset(
style_preset_id: str = Path(description="The style preset to delete"),
) -> None:
"""Deletes a style preset"""
try:
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
except StylePresetImageFileNotFoundException:
pass
ApiDependencies.invoker.services.style_preset_records.delete(style_preset_id)
@style_presets_router.post(
"/",
operation_id="create_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def create_style_preset(
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
data: str = Form(description="The data of the style preset to create"),
) -> StylePresetRecordWithImage:
"""Creates a style preset"""
try:
parsed_data = json.loads(data)
validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name
type = validated_data.type
positive_prompt = validated_data.positive_prompt
negative_prompt = validated_data.negative_prompt
except pydantic.ValidationError:
raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
style_preset = StylePresetWithoutId(name=name, preset_data=preset_data, type=type)
new_style_preset = ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset)
if image is not None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
try:
ApiDependencies.invoker.services.style_preset_image_files.save(new_style_preset.id, pil_image)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(new_style_preset.id)
return StylePresetRecordWithImage(image=preset_image, **new_style_preset.model_dump())
@style_presets_router.get(
"/",
operation_id="list_style_presets",
responses={
200: {"model": list[StylePresetRecordWithImage]},
},
)
async def list_style_presets() -> list[StylePresetRecordWithImage]:
"""Gets a page of style presets"""
style_presets_with_image: list[StylePresetRecordWithImage] = []
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many()
for preset in style_presets:
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(preset.id)
style_preset_with_image = StylePresetRecordWithImage(image=image, **preset.model_dump())
style_presets_with_image.append(style_preset_with_image)
return style_presets_with_image
@style_presets_router.get(
"/i/{style_preset_id}/image",
operation_id="get_style_preset_image",
responses={
200: {
"description": "The style preset image was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The style preset image could not be found"},
},
status_code=200,
)
async def get_style_preset_image(
style_preset_id: str = Path(description="The id of the style preset image to get"),
) -> FileResponse:
"""Gets an image file that previews the model"""
try:
path = ApiDependencies.invoker.services.style_preset_image_files.get_path(style_preset_id)
response = FileResponse(
path,
media_type="image/png",
filename=style_preset_id + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)
@style_presets_router.get(
"/export",
operation_id="export_style_presets",
responses={200: {"content": {"text/csv": {}}, "description": "A CSV file with the requested data."}},
status_code=200,
)
async def export_style_presets():
# Create an in-memory stream to store the CSV data
output = io.StringIO()
writer = csv.writer(output)
# Write the header
writer.writerow(["name", "prompt", "negative_prompt"])
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many(type=PresetType.User)
for preset in style_presets:
writer.writerow([preset.name, preset.preset_data.positive_prompt, preset.preset_data.negative_prompt])
csv_data = output.getvalue()
output.close()
return Response(
content=csv_data,
media_type="text/csv",
headers={"Content-Disposition": "attachment; filename=prompt_templates.csv"},
)
@style_presets_router.post(
"/import",
operation_id="import_style_presets",
)
async def import_style_presets(file: UploadFile = File(description="The file to import")):
try:
style_presets = await parse_presets_from_file(file)
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
except InvalidPresetImportDataError as e:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=400, detail=str(e))
except UnsupportedFileTypeError as e:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail=str(e))

View File

@ -30,7 +30,6 @@ from invokeai.app.api.routers import (
images,
model_manager,
session_queue,
style_presets,
utilities,
workflows,
)
@ -56,13 +55,11 @@ mimetypes.add_type("text/css", ".css")
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")
loop = asyncio.new_event_loop()
@asynccontextmanager
async def lifespan(app: FastAPI):
# Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
yield
# Shut down threads
ApiDependencies.shutdown()
@ -109,7 +106,6 @@ app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
app.include_router(style_presets.style_presets_router, prefix="/api")
app.openapi = get_openapi_func(app)
@ -188,6 +184,8 @@ def invoke_api() -> None:
check_cudnn(logger)
# Start our own event loop for eventing usage
loop = asyncio.new_event_loop()
config = uvicorn.Config(
app=app,
host=app_config.host,

View File

@ -40,18 +40,14 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types
MainModel = "MainModelField"
FluxMainModel = "FluxMainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
VAEModel = "VAEModelField"
FluxVAEModel = "FluxVAEModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
T5EncoderModel = "T5EncoderModelField"
CLIPEmbedModel = "CLIPEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
# endregion
@ -129,17 +125,13 @@ class FieldDescriptions:
negative_cond = "Negative conditioning tensor"
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
t5_encoder = "T5 tokenizer and text encoder"
clip_embed_model = "CLIP Embed loader"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@ -239,12 +231,6 @@ class ColorField(BaseModel):
return (self.r, self.g, self.b, self.a)
class FluxConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@ -1,92 +0,0 @@
from typing import Literal
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@invocation(
"flux_text_encoder",
title="FLUX Text Encoding",
tags=["prompt", "conditioning", "flux"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a flux image."""
clip: CLIPField = InputField(
title="CLIP",
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
t5_max_seq_len: Literal[256, 512] = InputField(
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
)
prompt: str = InputField(description="Text prompt to encode.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
# Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
# scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
t5_embeddings = self._t5_encode(context)
clip_embeddings = self._clip_encode(context)
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
)
conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput.build(conditioning_name)
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, T5Tokenizer)
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
prompt_embeds = t5_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
prompt = [self.prompt]
with (
clip_text_encoder_info as clip_text_encoder,
clip_tokenizer_info as clip_tokenizer,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
pooled_prompt_embeds = clip_encoder(prompt)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return pooled_prompt_embeds

View File

@ -1,169 +0,0 @@
import torch
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_text_to_image",
title="FLUX Text to Image",
tags=["image", "flux"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Text-to-image generation using a FLUX model."""
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
title="Transformer",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
positive_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(
default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
)
guidance: float = InputField(
default=4.0,
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = self._run_diffusion(context)
image = self._run_vae_decoding(context, latents)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = torch.bfloat16
# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
transformer_info = context.models.load(self.transformer.transformer)
# Prepare input noise.
x = get_noise(
num_samples=1,
height=self.height,
width=self.width,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
seed=self.seed,
)
x, img_ids = prepare_latent_img_patches(x)
is_schnell = "schnell" in transformer_info.config.config_path
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=x.shape[1],
shift=not is_schnell,
)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
with transformer_info as transformer:
assert isinstance(transformer, Flux)
def step_callback() -> None:
if context.util.is_canceled():
raise CanceledException
# TODO: Make this look like the image before re-enabling
# latent_image = unpack(img.float(), self.height, self.width)
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
# # Create a new tensor of the required shape [255, 255, 3]
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
# # Convert to a NumPy array and then to a PIL Image
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
# (width, height) = image.size
# width *= 8
# height *= 8
# dataURL = image_to_dataURL(image, image_format="JPEG")
# # TODO: move this whole function to invocation context to properly reference these variables
# context._services.events.emit_invocation_denoise_progress(
# context._data.queue_item,
# context._data.invocation,
# state,
# ProgressImage(dataURL=dataURL, width=width, height=height),
# )
x = denoise(
model=transformer,
img=x,
img_ids=img_ids,
txt=t5_embeddings,
txt_ids=txt_ids,
vec=clip_embeddings,
timesteps=timesteps,
step_callback=step_callback,
guidance=self.guidance,
)
x = unpack(x.float(), self.height, self.width)
return x
def _run_vae_decoding(
self,
context: InvocationContext,
latents: torch.Tensor,
) -> Image.Image:
vae_info = context.models.load(self.vae.vae)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
img = vae.decode(latents)
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
return img_pil

View File

@ -1,5 +1,5 @@
import copy
from typing import List, Literal, Optional
from typing import List, Optional
from pydantic import BaseModel, Field
@ -13,14 +13,7 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
class ModelIdentifierField(BaseModel):
@ -67,15 +60,6 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@ -138,78 +122,6 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
@invocation(
"main_model_loader",
title="Main Model",

View File

@ -12,7 +12,6 @@ from invokeai.app.invocations.fields import (
ConditioningField,
DenoiseMaskField,
FieldDescriptions,
FluxConditioningField,
ImageField,
Input,
InputField,
@ -415,17 +414,6 @@ class MaskOutput(BaseInvocationOutput):
height: int = OutputField(description="The height of the mask in pixels.")
@invocation_output("flux_conditioning_output")
class FluxConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""
conditioning: FluxConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "FluxConditioningOutput":
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""

View File

@ -91,7 +91,6 @@ class InvokeAIAppConfig(BaseSettings):
db_dir: Path to InvokeAI databases directory.
outputs_dir: Path to directory for outputs.
custom_nodes_dir: Path to directory for custom nodes.
style_presets_dir: Path to directory for style presets.
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
@ -154,7 +153,6 @@ class InvokeAIAppConfig(BaseSettings):
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
style_presets_dir: Path = Field(default=Path("style_presets"), description="Path to directory for style presets.")
# LOGGING
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
@ -302,11 +300,6 @@ class InvokeAIAppConfig(BaseSettings):
"""Path to the models directory, resolved to an absolute path.."""
return self._resolve(self.models_dir)
@property
def style_presets_path(self) -> Path:
"""Path to the style presets directory, resolved to an absolute path.."""
return self._resolve(self.style_presets_dir)
@property
def convert_cache_path(self) -> Path:
"""Path to the converted cache models directory, resolved to an absolute path.."""

View File

@ -1,44 +1,46 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.events.events_common import (
EventBase,
)
class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int, loop: asyncio.AbstractEventLoop) -> None:
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self._queue = asyncio.Queue[EventBase | None]()
self._queue = Queue[EventBase | None]()
self._stop_event = threading.Event()
self._loop = loop
# We need to store a reference to the task so it doesn't get GC'd
# See: https://docs.python.org/3/library/asyncio-task.html#creating-tasks
self._background_tasks: set[asyncio.Task[None]] = set()
task = self._loop.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.remove)
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self._stop_event.set()
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
self._queue.put(None)
def dispatch(self, event: EventBase) -> None:
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
self._queue.put(event)
async def _dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = await self._queue.get()
event = self._queue.get(block=False)
if not event: # Probably stopping
continue
# Leave the payloads as live pydantic models
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@ -4,8 +4,6 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
if TYPE_CHECKING:
from logging import Logger
@ -63,8 +61,6 @@ class InvocationServices:
workflow_records: "WorkflowRecordsStorageBase",
tensors: "ObjectSerializerBase[torch.Tensor]",
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
style_preset_records: "StylePresetRecordsStorageBase",
style_preset_image_files: "StylePresetImageFileStorageBase",
):
self.board_images = board_images
self.board_image_records = board_image_records
@ -89,5 +85,3 @@ class InvocationServices:
self.workflow_records = workflow_records
self.tensors = tensors
self.conditioning = conditioning
self.style_preset_records = style_preset_records
self.style_preset_image_files = style_preset_image_files

View File

@ -783,9 +783,8 @@ class ModelInstallService(ModelInstallServiceBase):
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
if subfolder:
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
path_to_remove = top / subfolder # sdxl-turbo/vae/
subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
path_to_add = Path(f"{top}_{subfolder_rename}")
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
path_to_add = Path(f"{top}_{subfolder}")
else:
path_to_remove = Path(".")
path_to_add = Path(".")

View File

@ -77,7 +77,6 @@ class ModelRecordChanges(BaseModelExcludeNull):
type: Optional[ModelType] = Field(description="Type of model", default=None)
key: Optional[str] = Field(description="Database ID for this model", default=None)
hash: Optional[str] = Field(description="hash of model file", default=None)
format: Optional[str] = Field(description="format of model file", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None

View File

@ -16,7 +16,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -50,7 +49,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
migrator.register_migration(build_migration_12(app_config=config))
migrator.register_migration(build_migration_13())
migrator.register_migration(build_migration_14())
migrator.run_migrations()
return db

View File

@ -1,61 +0,0 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration14Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._create_style_presets(cursor)
def _create_style_presets(self, cursor: sqlite3.Cursor) -> None:
"""Create the table used to store style presets."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS style_presets (
id TEXT NOT NULL PRIMARY KEY,
name TEXT NOT NULL,
preset_data TEXT NOT NULL,
type TEXT NOT NULL DEFAULT "user",
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
]
# Add trigger for `updated_at`.
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS style_presets
AFTER UPDATE
ON style_presets FOR EACH ROW
BEGIN
UPDATE style_presets SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
]
# Add indexes for searchable fields
indices = [
"CREATE INDEX IF NOT EXISTS idx_style_presets_name ON style_presets(name);",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def build_migration_14() -> Migration:
"""
Build the migration from database version 13 to 14..
This migration does the following:
- Create the table used to store style presets.
"""
migration_14 = Migration(
from_version=13,
to_version=14,
callback=Migration14Callback(),
)
return migration_14

Binary file not shown.

Before

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 160 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 146 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 117 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 156 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 141 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 96 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 88 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 132 KiB

View File

@ -1,33 +0,0 @@
from abc import ABC, abstractmethod
from pathlib import Path
from PIL.Image import Image as PILImageType
class StylePresetImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, style_preset_id: str) -> PILImageType:
"""Retrieves a style preset image as PIL Image."""
pass
@abstractmethod
def get_path(self, style_preset_id: str) -> Path:
"""Gets the internal path to a style preset image."""
pass
@abstractmethod
def get_url(self, style_preset_id: str) -> str | None:
"""Gets the URL to fetch a style preset image."""
pass
@abstractmethod
def save(self, style_preset_id: str, image: PILImageType) -> None:
"""Saves a style preset image."""
pass
@abstractmethod
def delete(self, style_preset_id: str) -> None:
"""Deletes a style preset image."""
pass

View File

@ -1,19 +0,0 @@
class StylePresetImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message: str = "Style preset image file not found"):
super().__init__(message)
class StylePresetImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message: str = "Style preset image file not saved"):
super().__init__(message)
class StylePresetImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message: str = "Style preset image file not deleted"):
super().__init__(message)

View File

@ -1,88 +0,0 @@
from pathlib import Path
from PIL import Image
from PIL.Image import Image as PILImageType
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
from invokeai.app.services.style_preset_images.style_preset_images_common import (
StylePresetImageFileDeleteException,
StylePresetImageFileNotFoundException,
StylePresetImageFileSaveException,
)
from invokeai.app.services.style_preset_records.style_preset_records_common import PresetType
from invokeai.app.util.misc import uuid_string
from invokeai.app.util.thumbnails import make_thumbnail
class StylePresetImageFileStorageDisk(StylePresetImageFileStorageBase):
"""Stores images on disk"""
def __init__(self, style_preset_images_folder: Path):
self._style_preset_images_folder = style_preset_images_folder
self._validate_storage_folders()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def get(self, style_preset_id: str) -> PILImageType:
try:
path = self.get_path(style_preset_id)
return Image.open(path)
except FileNotFoundError as e:
raise StylePresetImageFileNotFoundException from e
def save(self, style_preset_id: str, image: PILImageType) -> None:
try:
self._validate_storage_folders()
image_path = self._style_preset_images_folder / (style_preset_id + ".webp")
thumbnail = make_thumbnail(image, 256)
thumbnail.save(image_path, format="webp")
except Exception as e:
raise StylePresetImageFileSaveException from e
def get_path(self, style_preset_id: str) -> Path:
style_preset = self._invoker.services.style_preset_records.get(style_preset_id)
if style_preset.type is PresetType.Default:
default_images_dir = Path(__file__).parent / Path("default_style_preset_images")
path = default_images_dir / (style_preset.name + ".png")
else:
path = self._style_preset_images_folder / (style_preset_id + ".webp")
return path
def get_url(self, style_preset_id: str) -> str | None:
path = self.get_path(style_preset_id)
if not self._validate_path(path):
return
url = self._invoker.services.urls.get_style_preset_image_url(style_preset_id)
# The image URL never changes, so we must add random query string to it to prevent caching
url += f"?{uuid_string()}"
return url
def delete(self, style_preset_id: str) -> None:
try:
path = self.get_path(style_preset_id)
if not self._validate_path(path):
raise StylePresetImageFileNotFoundException
path.unlink()
except StylePresetImageFileNotFoundException as e:
raise StylePresetImageFileNotFoundException from e
except Exception as e:
raise StylePresetImageFileDeleteException from e
def _validate_path(self, path: Path) -> bool:
"""Validates the path given for an image."""
return path.exists()
def _validate_storage_folders(self) -> None:
"""Checks if the required folders exist and create them if they don't"""
self._style_preset_images_folder.mkdir(parents=True, exist_ok=True)

View File

@ -1,146 +0,0 @@
[
{
"name": "Photography (General)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}. photography. f/2.8 macro photo, bokeh, photorealism",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Studio Lighting)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}, photography. f/8 photo. centered subject, studio lighting.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Landscape)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}, landscape photograph, f/12, lifelike, highly detailed.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Portrait)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}. photography. portraiture. catch light in eyes. one flash. rembrandt lighting. Soft box. dark shadows. High contrast. 80mm lens. F2.8.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Black and White)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} photography. natural light. 80mm lens. F1.4. strong contrast, hard light. dark contrast. blurred background. black and white",
"negative_prompt": "painting, digital art. sketch, colour+"
}
},
{
"name": "Architectural Visualization",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}. architectural photography, f/12, luxury, aesthetically pleasing form and function.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Concept Art (Fantasy)",
"type": "default",
"preset_data": {
"positive_prompt": "concept artwork of a {prompt}. (digital painterly art style)++, mythological, (textured 2d dry media brushpack)++, glazed brushstrokes, otherworldly. painting+, illustration+",
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
}
},
{
"name": "Concept Art (Sci-Fi)",
"type": "default",
"preset_data": {
"positive_prompt": "(concept art)++, {prompt}, (sleek futurism)++, (textured 2d dry media)++, metallic highlights, digital painting style",
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
}
},
{
"name": "Concept Art (Character)",
"type": "default",
"preset_data": {
"positive_prompt": "(character concept art)++, stylized painterly digital painting of {prompt}, (painterly, impasto. Dry brush.)++",
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
}
},
{
"name": "Concept Art (Painterly)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} oil painting. high contrast. impasto. sfumato. chiaroscuro. Palette knife.",
"negative_prompt": "photo. smooth. border. frame"
}
},
{
"name": "Environment Art",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} environment artwork, hyper-realistic digital painting style with cinematic composition, atmospheric, depth and detail, voluminous. textured dry brush 2d media",
"negative_prompt": "photo, distorted, blurry, out of focus. sketch."
}
},
{
"name": "Interior Design (Visualization)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} interior design photo, gentle shadows, light mid-tones, dimension, mix of smooth and textured surfaces, focus on negative space and clean lines, focus",
"negative_prompt": "photo, distorted. sketch."
}
},
{
"name": "Product Rendering",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} high quality product photography, 3d rendering with key lighting, shallow depth of field, simple plain background, studio lighting.",
"negative_prompt": "blurry, sketch, messy, dirty. unfinished."
}
},
{
"name": "Sketch",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} black and white pencil drawing, off-center composition, cross-hatching for shadows, bold strokes, textured paper. sketch+++",
"negative_prompt": "blurry, photo, painting, color. messy, dirty. unfinished. frame, borders."
}
},
{
"name": "Line Art",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} Line art. bold outline. simplistic. white background. 2d",
"negative_prompt": "photo. digital art. greyscale. solid black. painting"
}
},
{
"name": "Anime",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} anime++, bold outline, cel-shaded coloring, shounen, seinen",
"negative_prompt": "(photo)+++. greyscale. solid black. painting"
}
},
{
"name": "Illustration",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} illustration, bold linework, illustrative details, vector art style, flat coloring",
"negative_prompt": "(photo)+++. greyscale. painting, black and white."
}
},
{
"name": "Vehicles",
"type": "default",
"preset_data": {
"positive_prompt": "A weird futuristic normal auto, {prompt} elegant design, nice color, nice wheels",
"negative_prompt": "sketch. digital art. greyscale. painting"
}
}
]

View File

@ -1,42 +0,0 @@
from abc import ABC, abstractmethod
from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetType,
StylePresetChanges,
StylePresetRecordDTO,
StylePresetWithoutId,
)
class StylePresetRecordsStorageBase(ABC):
"""Base class for style preset storage services."""
@abstractmethod
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Get style preset by id."""
pass
@abstractmethod
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
"""Creates a style preset."""
pass
@abstractmethod
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
"""Creates many style presets."""
pass
@abstractmethod
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
"""Updates a style preset."""
pass
@abstractmethod
def delete(self, style_preset_id: str) -> None:
"""Deletes a style preset."""
pass
@abstractmethod
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
"""Gets many workflows."""
pass

View File

@ -1,139 +0,0 @@
import codecs
import csv
import json
from enum import Enum
from typing import Any, Optional
import pydantic
from fastapi import UploadFile
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
from invokeai.app.util.metaenum import MetaEnum
class StylePresetNotFoundError(Exception):
"""Raised when a style preset is not found"""
class PresetData(BaseModel, extra="forbid"):
positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt")
PresetDataValidator = TypeAdapter(PresetData)
class PresetType(str, Enum, metaclass=MetaEnum):
User = "user"
Default = "default"
Project = "project"
class StylePresetChanges(BaseModel, extra="forbid"):
name: Optional[str] = Field(default=None, description="The style preset's new name.")
preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.")
type: Optional[PresetType] = Field(description="The updated type of the style preset")
class StylePresetWithoutId(BaseModel):
name: str = Field(description="The name of the style preset.")
preset_data: PresetData = Field(description="The preset data")
type: PresetType = Field(description="The type of style preset")
class StylePresetRecordDTO(StylePresetWithoutId):
id: str = Field(description="The style preset ID.")
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO":
data["preset_data"] = PresetDataValidator.validate_json(data.get("preset_data", ""))
return StylePresetRecordDTOValidator.validate_python(data)
StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
class StylePresetRecordWithImage(StylePresetRecordDTO):
image: Optional[str] = Field(description="The path for image")
class StylePresetImportRow(BaseModel):
name: str = Field(min_length=1, description="The name of the preset.")
positive_prompt: str = Field(
default="",
description="The positive prompt for the preset.",
validation_alias=AliasChoices("positive_prompt", "prompt"),
)
negative_prompt: str = Field(default="", description="The negative prompt for the preset.")
model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
StylePresetImportList = list[StylePresetImportRow]
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
class UnsupportedFileTypeError(ValueError):
"""Raised when an unsupported file type is encountered"""
pass
class InvalidPresetImportDataError(ValueError):
"""Raised when invalid preset import data is encountered"""
pass
async def parse_presets_from_file(file: UploadFile) -> list[StylePresetWithoutId]:
"""Parses style presets from a file. The file must be a CSV or JSON file.
If CSV, the file must have the following columns:
- name
- prompt (or positive_prompt)
- negative_prompt
If JSON, the file must be a list of objects with the following keys:
- name
- prompt (or positive_prompt)
- negative_prompt
Args:
file (UploadFile): The file to parse.
Returns:
list[StylePresetWithoutId]: The parsed style presets.
Raises:
UnsupportedFileTypeError: If the file type is not supported.
InvalidPresetImportDataError: If the data in the file is invalid.
"""
if file.content_type not in ["text/csv", "application/json"]:
raise UnsupportedFileTypeError()
if file.content_type == "text/csv":
csv_reader = csv.DictReader(codecs.iterdecode(file.file, "utf-8"))
data = list(csv_reader)
else: # file.content_type == "application/json":
json_data = await file.read()
data = json.loads(json_data)
try:
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
style_presets: list[StylePresetWithoutId] = []
for imported in imported_presets:
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
style_presets.append(style_preset)
except pydantic.ValidationError as e:
if file.content_type == "text/csv":
msg = "Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
else: # file.content_type == "application/json":
msg = "Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
raise InvalidPresetImportDataError(msg) from e
finally:
file.file.close()
return style_presets

View File

@ -1,215 +0,0 @@
import json
from pathlib import Path
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetType,
StylePresetChanges,
StylePresetNotFoundError,
StylePresetRecordDTO,
StylePresetWithoutId,
)
from invokeai.app.util.misc import uuid_string
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
self._sync_default_style_presets()
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = self._cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
name,
preset_data,
type
)
VALUES (?, ?, ?, ?);
""",
(
style_preset_id,
style_preset.name,
style_preset.preset_data.model_dump_json(),
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
try:
self._lock.acquire()
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
name,
preset_data,
type
)
VALUES (?, ?, ?, ?);
""",
(
style_preset_id,
style_preset.name,
style_preset.preset_data.model_dump_json(),
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
try:
self._lock.acquire()
# Change the name of a style preset
if changes.name is not None:
self._cursor.execute(
"""--sql
UPDATE style_presets
SET name = ?
WHERE id = ?;
""",
(changes.name, style_preset_id),
)
# Change the preset data for a style preset
if changes.preset_data is not None:
self._cursor.execute(
"""--sql
UPDATE style_presets
SET preset_data = ?
WHERE id = ?;
""",
(changes.preset_data.model_dump_json(), style_preset_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE from style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
try:
self._lock.acquire()
main_query = """
SELECT
*
FROM style_presets
"""
if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
if type is not None:
self._cursor.execute(main_query, (type,))
else:
self._cursor.execute(main_query)
rows = self._cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
return style_presets
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
# First delete all existing default style presets
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM style_presets
WHERE type = "default";
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
# Next, parse and create the default style presets
with self._lock, open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)
for preset in presets:
style_preset = StylePresetWithoutId.model_validate(preset)
self.create(style_preset)

View File

@ -13,8 +13,3 @@ class UrlServiceBase(ABC):
def get_model_image_url(self, model_key: str) -> str:
"""Gets the URL for a model image"""
pass
@abstractmethod
def get_style_preset_image_url(self, style_preset_id: str) -> str:
"""Gets the URL for a style preset image"""
pass

View File

@ -19,6 +19,3 @@ class LocalUrlService(UrlServiceBase):
def get_model_image_url(self, model_key: str) -> str:
return f"{self._base_url_v2}/models/i/{model_key}/image"
def get_style_preset_image_url(self, style_preset_id: str) -> str:
return f"{self._base_url}/style_presets/i/{style_preset_id}/image"

View File

@ -1,260 +0,0 @@
{
"name": "FLUX Text to Image",
"author": "InvokeAI",
"description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
"version": "1.0.4",
"contact": "",
"tags": "text2image, flux",
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
"exposedFields": [
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "model"
},
{
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"fieldName": "prompt"
},
{
"nodeId": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"fieldName": "num_steps"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "t5_encoder_model"
}
],
"meta": {
"version": "3.0.0",
"category": "default"
},
"nodes": [
{
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "invocation",
"data": {
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "flux_model_loader",
"version": "1.0.4",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"inputs": {
"model": {
"name": "model",
"label": ""
},
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
},
"clip_embed_model": {
"name": "clip_embed_model",
"label": ""
},
"vae_model": {
"name": "vae_model",
"label": ""
}
}
},
"position": {
"x": 381.1882713063478,
"y": -95.89663532854017
}
},
{
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"type": "invocation",
"data": {
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"type": "flux_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"clip": {
"name": "clip",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"t5_max_seq_len": {
"name": "t5_max_seq_len",
"label": "T5 Max Seq Len",
"value": 256
},
"prompt": {
"name": "prompt",
"label": "",
"value": "a cat"
}
}
},
"position": {
"x": 824.1970602278849,
"y": 146.98251001061735
}
},
{
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
"type": "invocation",
"data": {
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
"type": "rand_int",
"version": "1.0.1",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"inputs": {
"low": {
"name": "low",
"label": "",
"value": 0
},
"high": {
"name": "high",
"label": "",
"value": 2147483647
}
}
},
"position": {
"x": 822.9899179655476,
"y": 360.9657214885052
}
},
{
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"type": "invocation",
"data": {
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"type": "flux_text_to_image",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"transformer": {
"name": "transformer",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
},
"positive_text_conditioning": {
"name": "positive_text_conditioning",
"label": ""
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"num_steps": {
"name": "num_steps",
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
"value": 30
},
"guidance": {
"name": "guidance",
"label": "",
"value": 4
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 1216.3900791301849,
"y": 5.500841807102248
}
}
],
"edges": [
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "max_seq_len",
"targetHandle": "t5_max_seq_len"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
"type": "default",
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "conditioning",
"targetHandle": "positive_text_conditioning"
},
{
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-159bdf1b-79e7-4174-b86e-d40e646964c8seed",
"type": "default",
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "value",
"targetHandle": "seed"
}
]
}

View File

@ -81,7 +81,7 @@ def get_openapi_func(
# Add the output map to the schema
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": dict(sorted(invocation_output_map_properties.items())),
"properties": invocation_output_map_properties,
"required": invocation_output_map_required,
}

View File

@ -1,32 +0,0 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import torch
from einops import rearrange
from torch import Tensor
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@ -1,117 +0,0 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass
class FluxParams:
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool
class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, params: FluxParams):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(params.depth_single_blocks)
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
def forward(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img

View File

@ -1,310 +0,0 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
@dataclass
class AutoEncoderParams:
resolution: int
in_channels: int
ch: int
out_ch: int
ch_mult: list[int]
num_res_blocks: int
z_channels: int
scale_factor: float
shift_factor: float
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def attention(self, h_: Tensor) -> Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h = x
h = self.norm1(h)
h = torch.nn.functional.silu(h)
h = self.conv1(h)
h = self.norm2(h)
h = torch.nn.functional.silu(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
# no asymmetric padding in torch conv, must do it ourselves
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x: Tensor):
pad = (0, 1, 0, 1)
x = nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor):
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
block_in = self.ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor) -> Tensor:
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = torch.nn.functional.silu(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z: Tensor) -> Tensor:
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = torch.nn.functional.silu(h)
h = self.conv_out(h)
return h
class DiagonalGaussian(nn.Module):
def __init__(self, sample: bool = True, chunk_dim: int = 1):
super().__init__()
self.sample = sample
self.chunk_dim = chunk_dim
def forward(self, z: Tensor) -> Tensor:
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
else:
return mean
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.reg = DiagonalGaussian()
self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor
def encode(self, x: Tensor) -> Tensor:
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z
def decode(self, z: Tensor) -> Tensor:
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)
def forward(self, x: Tensor) -> Tensor:
return self.decode(self.encode(x))

View File

@ -1,33 +0,0 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from torch import Tensor, nn
from transformers import PreTrainedModel, PreTrainedTokenizer
class HFEncoder(nn.Module):
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
super().__init__()
self.max_length = max_length
self.is_clip = is_clip
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
self.tokenizer = tokenizer
self.hf_module = encoder
self.hf_module = self.hf_module.eval().requires_grad_(False)
def forward(self, text: list[str]) -> Tensor:
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
outputs = self.hf_module(
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
attention_mask=None,
output_hidden_states=False,
)
return outputs[self.output_key]

View File

@ -1,253 +0,0 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import math
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
from invokeai.backend.flux.math import attention, rope
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: Tensor) -> Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.scale = nn.Parameter(torch.ones(dim))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * self.scale
class QKNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
shift: Tensor
scale: Tensor
gate: Tensor
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float | None = None,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

View File

@ -1,167 +0,0 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import math
from typing import Callable
import torch
from einops import rearrange, repeat
from torch import Tensor
from tqdm import tqdm
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.conditioner import HFEncoder
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}
def time_shift(mu: float, sigma: float, t: Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# eastimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: Flux,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[], None],
guidance: float = 4.0,
):
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
img = img + (t_prev - t_curr) * pred
step_callback()
return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert an input image in latent space to patches for diffusion.
This implementation was extracted from:
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
Returns:
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
"""
bs, c, h, w = latent_img.shape
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
# Generate patch position ids.
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
return img, img_ids

View File

@ -1,71 +0,0 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
from typing import Dict, Literal
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
@dataclass
class ModelSpec:
params: FluxParams
ae_params: AutoEncoderParams
ckpt_path: str | None
ae_path: str | None
repo_id: str | None
repo_flow: str | None
repo_ae: str | None
max_seq_lengths: Dict[str, Literal[256, 512]] = {
"flux-dev": 512,
"flux-schnell": 256,
}
ae_params = {
"flux": AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
)
}
params = {
"flux-dev": FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
"flux-schnell": FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
}

View File

@ -52,7 +52,6 @@ class BaseModelType(str, Enum):
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
# Kandinsky2_1 = "kandinsky-2.1"
@ -67,9 +66,7 @@ class ModelType(str, Enum):
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
CLIPEmbed = "clip_embed"
T2IAdapter = "t2i_adapter"
T5Encoder = "t5_encoder"
SpandrelImageToImage = "spandrel_image_to_image"
@ -77,7 +74,6 @@ class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
@ -108,9 +104,6 @@ class ModelFormat(str, Enum):
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
InvokeAI = "invokeai"
T5Encoder = "t5_encoder"
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
BnbQuantizednf4b = "bnb_quantized_nf4b"
class SchedulerPredictionType(str, Enum):
@ -193,9 +186,7 @@ class ModelConfigBase(BaseModel):
class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
)
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config_path: str = Field(description="path to the checkpoint model config file")
converted_at: Optional[float] = Field(
description="When this model was last converted to diffusers", default_factory=time.time
@ -214,26 +205,6 @@ class LoRAConfigBase(ModelConfigBase):
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
class T5EncoderConfigBase(ModelConfigBase):
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
class T5EncoderConfig(T5EncoderConfigBase):
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase):
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.BnbQuantizedLlmInt8b.value}")
class LoRALyCORISConfig(LoRAConfigBase):
"""Model config for LoRA/Lycoris models."""
@ -258,6 +229,7 @@ class VAECheckpointConfig(CheckpointConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
@ -296,6 +268,7 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
@ -344,21 +317,6 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
"""Model config for main checkpoint models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.format = ModelFormat.BnbQuantizednf4b
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
"""Model config for main diffusers models."""
@ -392,17 +350,6 @@ class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
"""Model config for Clip Embeddings."""
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision."""
@ -461,15 +408,12 @@ AnyModelConfig = Annotated[
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
@ -477,7 +421,6 @@ AnyModelConfig = Annotated[
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]

View File

@ -72,7 +72,6 @@ class ModelLoader(ModelLoaderBase):
pass
config.path = str(self._get_model_path(config))
self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
loaded_model = self._load_model(config, submodel_type)
self._ram_cache.put(

View File

@ -193,6 +193,15 @@ class ModelCacheBase(ABC, Generic[T]):
"""
pass
@abstractmethod
def exists(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
) -> bool:
"""Return true if the model identified by key and submodel_type is in the cache."""
pass
@abstractmethod
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""

View File

@ -1,6 +1,22 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
""" """
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:
cache = ModelCache(max_cache_size=7.5)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2)
"""
import gc
import math
@ -24,64 +40,45 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
# Size of a GB in bytes.
GB = 2**30
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
# Size of a MB in bytes.
MB = 2**20
class ModelCache(ModelCacheBase[AnyModel]):
"""A cache for managing models in memory.
The cache is based on two levels of model storage:
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
The model cache is based on the following assumptions:
- storage_device_mem_size > execution_device_mem_size
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
the execution_device.
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
configuration.
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
the context, and unload outside the context.
Example usage:
```
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
do_something_on_gpu(SD1)
```
"""
"""Implementation of ModelCacheBase."""
def __init__(
self,
max_cache_size: float,
max_vram_cache_size: float,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
"""
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the storage_device cache in GBs.
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded.
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
@ -89,6 +86,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""
# allow lazy offloading only when vram cache enabled
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
@ -147,6 +145,15 @@ class ModelCache(ModelCacheBase[AnyModel]):
total += cache_record.size
return total
def exists(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
) -> bool:
"""Return true if the model identified by key and submodel_type is in the cache."""
key = self._make_cache_key(key, submodel_type)
return key in self._cached_models
def put(
self,
key: str,
@ -196,7 +203,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
# more stats
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GB)
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
@ -224,13 +231,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
return model_key
def offload_unlocked_models(self, size_required: int) -> None:
"""Offload models from the execution_device to make room for size_required.
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
"""
reserved = self._max_vram_cache_size * GB
"""Move any unused models from VRAM."""
reserved = self._max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
@ -241,7 +245,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
)
TorchDevice.empty_cache()
@ -299,7 +303,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self.logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
@ -322,14 +326,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GB):.3f} GB.\n"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
ram = "%4.2fG" % (self.cache_size() / GB)
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % (self.cache_size() / GIG)
in_ram_models = 0
in_vram_models = 0
@ -349,20 +353,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
)
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size.
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
external references to the model, there's nothing that the cache can do about it, and those models will not be
garbage-collected.
"""
"""Make enough room in the cache to accommodate a new model of indicated size."""
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = size
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = self.cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GB):.2f} GB"
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GIG):.2f} GB"
)
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
@ -379,7 +380,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
if not cache_entry.locked:
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1

View File

@ -1,234 +0,0 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
import accelerate
import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, params
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
CLIPEmbedDiffusersConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
T5EncoderBnbQuantizedLlmInt8bConfig,
T5EncoderConfig,
VAECheckpointConfig,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.util.silence_warnings import SilenceWarnings
try:
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
bnb_available = True
except ImportError:
bnb_available = False
app_config = get_config()
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class FluxVAELoader(ModelLoader):
"""Class to load VAE models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, VAECheckpointConfig):
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
model_path = Path(config.path)
with SilenceWarnings():
model = AutoEncoder(ae_params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
model.to(dtype=self._torch_dtype)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
class ClipCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CLIPEmbedDiffusersConfig):
raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer:
return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer")
case SubModelType.TextEncoder:
return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder")
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig):
raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.")
if not bnb_available:
raise ImportError(
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
match submodel_type:
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
with accelerate.init_empty_weights():
model = AutoModelForTextEncoding.from_config(model_config)
model = quantize_model_llm_int8(model, modules_to_not_convert=set())
state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors"
state_dict = load_file(state_dict_path)
self._load_state_dict_into_t5(model, state_dict)
return model
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
@classmethod
def _load_state_dict_into_t5(cls, model: T5EncoderModel, state_dict: dict[str, torch.Tensor]):
# There is a shared reference to a single weight tensor in the model.
# Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
# be present in the state_dict.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
assert len(unexpected_keys) == 0
assert set(missing_keys) == {"encoder.embed_tokens.weight"}
# Assert that the layers we expect to be shared are actually shared.
assert model.encoder.embed_tokens.weight is model.shared.weight
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
class T5EncoderCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5EncoderConfig):
raise ValueError("Only T5EncoderConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2")
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
class FluxCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config)
raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
def _load_from_singlefile(
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainCheckpointConfig)
model_path = Path(config.path)
with SilenceWarnings():
model = Flux(params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config)
raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
def _load_from_singlefile(
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
if not bnb_available:
raise ImportError(
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
model_path = Path(config.path)
with SilenceWarnings():
with accelerate.init_empty_weights():
model = Flux(params[config.config_path])
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
return model

View File

@ -78,12 +78,7 @@ class GenericDiffusersLoader(ModelLoader):
# TO DO: Add exception handling
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
if module in [
"diffusers",
"transformers",
"invokeai.backend.quantization.fast_quantized_transformers_model",
"invokeai.backend.quantization.fast_quantized_diffusion_model",
]:
if module in ["diffusers", "transformers"]:
res_type = sys.modules[module]
else:
res_type = sys.modules["diffusers"].pipelines

View File

@ -36,18 +36,8 @@ VARIANT_TO_IN_CHANNEL_MAP = {
}
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint
)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""

View File

@ -9,7 +9,7 @@ from typing import Optional
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
from transformers import CLIPTokenizer
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
@ -50,17 +50,6 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
),
):
return model.calc_size()
elif isinstance(
model,
(
T5TokenizerFast,
T5Tokenizer,
),
):
# HACK(ryand): len(model) just returns the vocabulary size, so this is blatantly wrong. It should be small
# relative to the text encoder that it's used with, so shouldn't matter too much, but we should fix this at some
# point.
return len(model)
else:
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
# supported model types.

View File

@ -95,7 +95,6 @@ class ModelProbe(object):
}
CLASS2TYPE = {
"FluxPipeline": ModelType.Main,
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
@ -107,7 +106,6 @@ class ModelProbe(object):
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
"CLIPModel": ModelType.CLIPEmbed,
}
@classmethod
@ -163,7 +161,7 @@ class ModelProbe(object):
fields["description"] = (
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
)
fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
fields["default_settings"] = fields.get("default_settings")
@ -178,10 +176,10 @@ class ModelProbe(object):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
ModelFormat.Checkpoint,
ModelFormat.BnbQuantizednf4b,
]:
if (
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
and fields["format"] is ModelFormat.Checkpoint
):
ckpt_config_path = cls._get_checkpoint_config_path(
model_path,
model_type=fields["type"],
@ -224,8 +222,7 @@ class ModelProbe(object):
ckpt = ckpt.get("state_dict", ckpt)
for key in [str(k) for k in ckpt.keys()]:
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
# Keys starting with double_blocks are associated with Flux models
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
return ModelType.Main
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
return ModelType.VAE
@ -324,27 +321,10 @@ class ModelProbe(object):
return possible_conf.absolute()
if model_type is ModelType.Main:
if base_type == BaseModelType.Flux:
# TODO: Decide between dev/schnell
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
if "guidance_in.out_layer.weight" in state_dict:
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-dev"
else:
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-schnell"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
config_file = f"stable-diffusion/{config_file}"
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
config_file = f"stable-diffusion/{config_file}"
elif model_type is ModelType.ControlNet:
config_file = (
"controlnet/cldm_v15.yaml"
@ -353,13 +333,7 @@ class ModelProbe(object):
)
elif model_type is ModelType.VAE:
config_file = (
# For flux, this is a key in invokeai.backend.flux.util.ae_params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
"flux"
if base_type is BaseModelType.Flux
else "stable-diffusion/v1-inference.yaml"
"stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1
else "stable-diffusion/sd_xl_base.yaml"
if base_type is BaseModelType.StableDiffusionXL
@ -442,15 +416,11 @@ class CheckpointProbeBase(ProbeBase):
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
def get_format(self) -> ModelFormat:
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
if "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict:
return ModelFormat.BnbQuantizednf4b
return ModelFormat("checkpoint")
def get_variant_type(self) -> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
base_type = self.get_base_type()
if model_type != ModelType.Main or base_type == BaseModelType.Flux:
if model_type != ModelType.Main:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
@ -470,8 +440,6 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
if "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
return BaseModelType.Flux
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
@ -514,7 +482,6 @@ class VaeCheckpointProbe(CheckpointProbeBase):
(r"xl", BaseModelType.StableDiffusionXL),
(r"sd2", BaseModelType.StableDiffusion2),
(r"vae", BaseModelType.StableDiffusion1),
(r"FLUX.1-schnell_ae", BaseModelType.Flux),
]:
if re.search(regexp, self.model_path.name, re.IGNORECASE):
return basetype
@ -746,11 +713,6 @@ class TextualInversionFolderProbe(FolderProbeBase):
return TextualInversionCheckpointProbe(path).get_base_type()
class T5EncoderFolderProbe(FolderProbeBase):
def get_format(self) -> ModelFormat:
return ModelFormat.T5Encoder
class ONNXFolderProbe(PipelineFolderProbe):
def get_base_type(self) -> BaseModelType:
# Due to the way the installer is set up, the configuration file for safetensors
@ -843,11 +805,6 @@ class CLIPVisionFolderProbe(FolderProbeBase):
return BaseModelType.Any
class CLIPEmbedFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
class SpandrelImageToImageFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
@ -878,10 +835,8 @@ ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPEmbed, CLIPEmbedFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)

View File

@ -2,7 +2,7 @@ from typing import Optional
from pydantic import BaseModel
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.config import BaseModelType, ModelType
class StarterModelWithoutDependencies(BaseModel):
@ -11,7 +11,6 @@ class StarterModelWithoutDependencies(BaseModel):
name: str
base: BaseModelType
type: ModelType
format: Optional[ModelFormat] = None
is_installed: bool = False
@ -52,76 +51,10 @@ cyberrealistic_negative = StarterModel(
type=ModelType.TextualInversion,
)
t5_base_encoder = StarterModel(
name="t5_base_encoder",
base=BaseModelType.Any,
source="InvokeAI/t5-v1_1-xxl::bfloat16",
description="T5-XXL text encoder (used in FLUX pipelines). ~8GB",
type=ModelType.T5Encoder,
)
t5_8b_quantized_encoder = StarterModel(
name="t5_bnb_int8_quantized_encoder",
base=BaseModelType.Any,
source="InvokeAI/t5-v1_1-xxl::bnb_llm_int8",
description="T5-XXL text encoder with bitsandbytes LLM.int8() quantization (used in FLUX pipelines). ~5GB",
type=ModelType.T5Encoder,
format=ModelFormat.BnbQuantizedLlmInt8b,
)
clip_l_encoder = StarterModel(
name="clip-vit-large-patch14",
base=BaseModelType.Any,
source="InvokeAI/clip-vit-large-patch14-text-encoder::bfloat16",
description="CLIP-L text encoder (used in FLUX pipelines). ~250MB",
type=ModelType.CLIPEmbed,
)
flux_vae = StarterModel(
name="FLUX.1-schnell_ae",
base=BaseModelType.Flux,
source="black-forest-labs/FLUX.1-schnell::ae.safetensors",
description="FLUX VAE compatible with both schnell and dev variants.",
type=ModelType.VAE,
)
# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
# region: Main
StarterModel(
name="FLUX Schnell (Quantized)",
base=BaseModelType.Flux,
source="InvokeAI/flux_schnell::transformer/bnb_nf4/flux1-schnell-bnb_nf4.safetensors",
description="FLUX schnell transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="FLUX Dev (Quantized)",
base=BaseModelType.Flux,
source="InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors",
description="FLUX dev transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="FLUX Schnell",
base=BaseModelType.Flux,
source="InvokeAI/flux_schnell::transformer/base/flux1-schnell.safetensors",
description="FLUX schnell transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="FLUX Dev",
base=BaseModelType.Flux,
source="InvokeAI/flux_dev::transformer/base/flux1-dev.safetensors",
description="FLUX dev transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="CyberRealistic v4.1",
base=BaseModelType.StableDiffusion1,
@ -192,7 +125,6 @@ STARTER_MODELS: list[StarterModel] = [
# endregion
# region VAE
sdxl_fp16_vae_fix,
flux_vae,
# endregion
# region LoRA
StarterModel(
@ -518,11 +450,6 @@ STARTER_MODELS: list[StarterModel] = [
type=ModelType.SpandrelImageToImage,
),
# endregion
# region TextEncoders
t5_base_encoder,
t5_8b_quantized_encoder,
clip_l_encoder,
# endregion
]
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"

View File

@ -54,7 +54,6 @@ def filter_files(
"lora_weights.safetensors",
"weights.pb",
"onnx_data",
"spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`.
)
):
paths.append(file)
@ -63,13 +62,13 @@ def filter_files(
# downloading random checkpoints that might also be in the repo. However there is no guarantee
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
# will adhere to this naming convention, so this is an area to be careful of.
elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
paths.append(file)
# limit search to subfolder if requested
if subfolder:
subfolder = root / subfolder
paths = [x for x in paths if Path(subfolder) in x.parents]
paths = [x for x in paths if x.parent == Path(subfolder)]
# _filter_by_variant uniquifies the paths and returns a set
return sorted(_filter_by_variant(paths, variant))
@ -98,9 +97,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
if variant == ModelRepoVariant.Flax:
result.add(path)
# Note: '.model' was added to support:
# https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model
elif path.suffix in [".json", ".txt", ".model"]:
elif path.suffix in [".json", ".txt"]:
result.add(path)
elif variant in [
@ -143,23 +140,6 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
continue
for candidate_list in subfolder_weights.values():
# Check if at least one of the files has the explicit fp16 variant.
at_least_one_fp16 = False
for candidate in candidate_list:
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
at_least_one_fp16 = True
break
if not at_least_one_fp16:
# If none of the candidates in this candidate_list have the explicit fp16 variant label, then this
# candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case,
# we'll simply keep all the candidates. An example of a model that hits this case is
# `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd).
for candidate in candidate_list:
result.add(candidate.path)
# The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring
# candidate.
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate:
result.add(highest_score_candidate.path)

View File

@ -1,135 +0,0 @@
import bitsandbytes as bnb
import torch
# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization.
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
class InvokeInt8Params(bnb.nn.Int8Params):
"""We override cuda() to avoid re-quantizing the weights in the following cases:
- We loaded quantized weights from a state_dict on the cpu, and then moved the model to the gpu.
- We are moving the model back-and-forth between the cpu and gpu.
"""
def cuda(self, device):
if self.has_fp16_weights:
return super().cuda(device)
elif self.CB is not None and self.SCB is not None:
self.data = self.data.cuda()
self.CB = self.data
self.SCB = self.SCB.cuda()
else:
# we store the 8-bit rows-major weight
# we convert this weight to the turning/ampere weight during the first inference pass
B = self.data.contiguous().half().cuda(device)
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
del CBt
del SCBt
self.data = CB
self.CB = CB
self.SCB = SCB
return self
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
def _load_from_state_dict(
self,
state_dict: dict[str, torch.Tensor],
prefix: str,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
weight = state_dict.pop(prefix + "weight")
bias = state_dict.pop(prefix + "bias", None)
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
scb = state_dict.pop(prefix + "SCB", None)
# Currently, we only support weight_format=0.
weight_format = state_dict.pop(prefix + "weight_format", None)
assert weight_format == 0
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
# rather than raising an exception to correctly implement this API.
assert len(state_dict) == 0
if scb is not None:
# We are loading a pre-quantized state dict.
self.weight = InvokeInt8Params(
data=weight,
requires_grad=self.weight.requires_grad,
has_fp16_weights=False,
# Note: After quantization, CB is the same as weight.
CB=weight,
SCB=scb,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
else:
# We are loading a non-quantized state dict.
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
self.weight = InvokeInt8Params(
data=weight,
requires_grad=self.weight.requires_grad,
has_fp16_weights=False,
CB=None,
SCB=None,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
# Reset the state. The persisted fields are based on the initialization behaviour in
# `bnb.nn.Linear8bitLt.__init__()`.
new_state = bnb.MatmulLtState()
new_state.threshold = self.state.threshold
new_state.has_fp16_weights = False
new_state.use_pool = self.state.use_pool
self.state = new_state
def _convert_linear_layers_to_llm_8bit(
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
) -> None:
"""Convert all linear layers in the module to bnb.nn.Linear8bitLt layers."""
for name, child in module.named_children():
fullname = f"{prefix}.{name}" if prefix else name
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
has_bias = child.bias is not None
replacement = InvokeLinear8bitLt(
child.in_features,
child.out_features,
bias=has_bias,
has_fp16_weights=False,
threshold=outlier_threshold,
)
replacement.weight.data = child.weight.data
if has_bias:
replacement.bias.data = child.bias.data
replacement.requires_grad_(False)
module.__setattr__(name, replacement)
else:
_convert_linear_layers_to_llm_8bit(
child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname
)
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
_convert_linear_layers_to_llm_8bit(
module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold
)
return model

View File

@ -1,156 +0,0 @@
import bitsandbytes as bnb
import torch
# This file contains utils for working with models that use bitsandbytes NF4 quantization.
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick
# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
class InvokeLinearNF4(bnb.nn.LinearNF4):
"""A class that extends `bnb.nn.LinearNF4` to add the following functionality:
- Ability to load Linear NF4 layers from a pre-quantized state_dict.
- Ability to load Linear NF4 layers from a state_dict when the model is on the "meta" device.
"""
def _load_from_state_dict(
self,
state_dict: dict[str, torch.Tensor],
prefix: str,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
"""
weight = state_dict.pop(prefix + "weight")
bias = state_dict.pop(prefix + "bias", None)
# We expect the remaining keys to be quant_state keys.
quant_state_sd = state_dict
# During serialization, the quant_state is stored as subkeys of "weight." (See
# `bnb.nn.LinearNF4._save_to_state_dict()`). We validate that they at least have the correct prefix.
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
# rather than raising an exception to correctly implement this API.
assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys())
if len(quant_state_sd) > 0:
# We are loading a pre-quantized state dict.
self.weight = bnb.nn.Params4bit.from_prequantized(
data=weight, quantized_stats=quant_state_sd, device=weight.device
)
self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False)
else:
# We are loading a non-quantized state dict.
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
self.weight = bnb.nn.Params4bit(
data=weight,
requires_grad=self.weight.requires_grad,
compress_statistics=self.weight.compress_statistics,
quant_type=self.weight.quant_type,
quant_storage=self.weight.quant_storage,
module=self,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
def _replace_param(
param: torch.nn.Parameter | bnb.nn.Params4bit,
data: torch.Tensor,
) -> torch.nn.Parameter:
"""A helper function to replace the data of a model parameter with new data in a way that allows replacing params on
the "meta" device.
Supports both `torch.nn.Parameter` and `bnb.nn.Params4bit` parameters.
"""
if param.device.type == "meta":
# Doing `param.data = data` raises a RuntimeError if param.data was on the "meta" device, so we need to
# re-create the param instead of overwriting the data.
if isinstance(param, bnb.nn.Params4bit):
return bnb.nn.Params4bit(
data,
requires_grad=data.requires_grad,
quant_state=param.quant_state,
compress_statistics=param.compress_statistics,
quant_type=param.quant_type,
)
return torch.nn.Parameter(data, requires_grad=data.requires_grad)
param.data = data
return param
def _convert_linear_layers_to_nf4(
module: torch.nn.Module,
ignore_modules: set[str],
compute_dtype: torch.dtype,
compress_statistics: bool = False,
prefix: str = "",
) -> None:
"""Convert all linear layers in the model to NF4 quantized linear layers.
Args:
module: All linear layers in this module will be converted.
ignore_modules: A set of module prefixes to ignore when converting linear layers.
compute_dtype: The dtype to use for computation in the quantized linear layers.
compress_statistics: Whether to enable nested quantization (aka double quantization) where the quantization
constants from the first quantization are quantized again.
prefix: The prefix of the current module in the model. Used to call this function recursively.
"""
for name, child in module.named_children():
fullname = f"{prefix}.{name}" if prefix else name
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
has_bias = child.bias is not None
replacement = InvokeLinearNF4(
child.in_features,
child.out_features,
bias=has_bias,
compute_dtype=compute_dtype,
compress_statistics=compress_statistics,
)
if has_bias:
replacement.bias = _replace_param(replacement.bias, child.bias.data)
replacement.weight = _replace_param(replacement.weight, child.weight.data)
replacement.requires_grad_(False)
module.__setattr__(name, replacement)
else:
_convert_linear_layers_to_nf4(child, ignore_modules, compute_dtype=compute_dtype, prefix=fullname)
def quantize_model_nf4(model: torch.nn.Module, modules_to_not_convert: set[str], compute_dtype: torch.dtype):
"""Apply bitsandbytes nf4 quantization to the model.
You likely want to call this function inside a `accelerate.init_empty_weights()` context.
Example usage:
```
# Initialize the model from a config on the meta device.
with accelerate.init_empty_weights():
model = ModelClass.from_config(...)
# Add NF4 quantization linear layers to the model - still on the meta device.
with accelerate.init_empty_weights():
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.float16)
# Load a state_dict into the model. (Could be either a prequantized or non-quantized state_dict.)
model.load_state_dict(state_dict, strict=True, assign=True)
# Move the model to the "cuda" device. If the model was non-quantized, this is where the weight quantization takes
# place.
model.to("cuda")
```
"""
_convert_linear_layers_to_nf4(module=model, ignore_modules=modules_to_not_convert, compute_dtype=compute_dtype)
return model

View File

@ -1,79 +0,0 @@
from pathlib import Path
import accelerate
from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
def main():
"""A script for quantizing a FLUX transformer model using the bitsandbytes LLM.int8() quantization method.
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
etc.) are hardcoded and would need to be modified for other use cases.
"""
# Load the FLUX transformer model onto the meta device.
model_path = Path(
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
)
with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():
model = Flux(p)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set()
model_int8_path = model_path.parent / "bnb_llm_int8.safetensors"
if model_int8_path.exists():
# The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
sd = load_file(model_int8_path)
model.load_state_dict(sd, strict=True, assign=True)
with log_time("Move model to cuda"):
model = model.to("cuda")
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
else:
# The quantized model does not exist, quantize the model and save it.
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
state_dict = load_file(model_path)
# TODO(ryand): Cast the state_dict to the appropriate dtype?
model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_int8_path.parent.mkdir(parents=True, exist_ok=True)
save_file(model.state_dict(), model_int8_path)
print(f"Successfully quantized and saved model to '{model_int8_path}'.")
assert isinstance(model, Flux)
return model
if __name__ == "__main__":
main()

View File

@ -1,96 +0,0 @@
import time
from contextlib import contextmanager
from pathlib import Path
import accelerate
import torch
from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@contextmanager
def log_time(name: str):
"""Helper context manager to log the time taken by a block of code."""
start = time.time()
try:
yield None
finally:
end = time.time()
print(f"'{name}' took {end - start:.4f} secs")
def main():
"""A script for quantizing a FLUX transformer model using the bitsandbytes NF4 quantization method.
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
etc.) are hardcoded and would need to be modified for other use cases.
"""
model_path = Path(
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
)
# inference_dtype = torch.bfloat16
with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():
model = Flux(p)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set()
model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
if model_nf4_path.exists():
# The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
# Replace the linear layers with NF4 quantized linear layers (still on the meta device).
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
model = quantize_model_nf4(
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
)
with log_time("Load state dict into model"):
state_dict = load_file(model_nf4_path)
model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda"):
model = model.to("cuda")
print(f"Successfully loaded pre-quantized model from '{model_nf4_path}'.")
else:
# The quantized model does not exist, quantize the model and save it.
print(f"No pre-quantized model found at '{model_nf4_path}'. Quantizing the model...")
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
model = quantize_model_nf4(
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
)
with log_time("Load state dict into model"):
state_dict = load_file(model_path)
# TODO(ryand): Cast the state_dict to the appropriate dtype?
model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
save_file(model.state_dict(), model_nf4_path)
print(f"Successfully quantized and saved model to '{model_nf4_path}'.")
assert isinstance(model, Flux)
return model
if __name__ == "__main__":
main()

View File

@ -1,92 +0,0 @@
from pathlib import Path
import accelerate
from safetensors.torch import load_file, save_file
from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
def load_state_dict_into_t5(model: T5EncoderModel, state_dict: dict):
# There is a shared reference to a single weight tensor in the model.
# Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
# be present in the state_dict.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
assert len(unexpected_keys) == 0
assert set(missing_keys) == {"encoder.embed_tokens.weight"}
# Assert that the layers we expect to be shared are actually shared.
assert model.encoder.embed_tokens.weight is model.shared.weight
def main():
"""A script for quantizing a T5 text encoder model using the bitsandbytes LLM.int8() quantization method.
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
etc.) are hardcoded and would need to be modified for other use cases.
"""
model_path = Path("/data/misc/text_encoder_2")
with log_time("Intialize T5 on meta device"):
model_config = AutoConfig.from_pretrained(model_path)
with accelerate.init_empty_weights():
model = AutoModelForTextEncoding.from_config(model_config)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set()
model_int8_path = model_path / "bnb_llm_int8.safetensors"
if model_int8_path.exists():
# The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
sd = load_file(model_int8_path)
load_state_dict_into_t5(model, sd)
with log_time("Move model to cuda"):
model = model.to("cuda")
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
else:
# The quantized model does not exist, quantize the model and save it.
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
# Load sharded state dict.
files = list(model_path.glob("*.safetensors"))
state_dict = {}
for file in files:
sd = load_file(file)
state_dict.update(sd)
load_state_dict_into_t5(model, state_dict)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_int8_path.parent.mkdir(parents=True, exist_ok=True)
state_dict = model.state_dict()
state_dict.pop("encoder.embed_tokens.weight")
save_file(state_dict, model_int8_path)
# This handling of shared weights could also be achieved with save_model(...), but then we'd lose control
# over which keys are kept. And, the corresponding load_model(...) function does not support assign=True.
# save_model(model, model_int8_path)
print(f"Successfully quantized and saved model to '{model_int8_path}'.")
assert isinstance(model, T5EncoderModel)
return model
if __name__ == "__main__":
main()

View File

@ -25,6 +25,11 @@ class BasicConditioningInfo:
return self
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo]
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
"""SDXL text conditioning information produced by Compel."""
@ -38,22 +43,6 @@ class SDXLConditioningInfo(BasicConditioningInfo):
return super().to(device=device, dtype=dtype)
@dataclass
class FLUXConditioningInfo:
clip_embeds: torch.Tensor
t5_embeds: torch.Tensor
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
return self
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
@dataclass
class IPAdapterConditioningInfo:
cond_image_prompt_embeds: torch.Tensor

View File

@ -3,9 +3,10 @@ Initialization file for invokeai.backend.util
"""
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import Chdir, directory_size
from invokeai.backend.util.util import GIG, Chdir, directory_size
__all__ = [
"GIG",
"directory_size",
"Chdir",
"InvokeAILogger",

View File

@ -7,6 +7,9 @@ from pathlib import Path
from PIL import Image
# actual size of a gig
GIG = 1073741824
def slugify(value: str, allow_unicode: bool = False) -> str:
"""

View File

@ -53,63 +53,64 @@
},
"dependencies": {
"@chakra-ui/react-use-size": "^2.1.0",
"@dagrejs/dagre": "^1.1.3",
"@dagrejs/graphlib": "^2.2.3",
"@dagrejs/dagre": "^1.1.2",
"@dagrejs/graphlib": "^2.2.2",
"@dnd-kit/core": "^6.1.0",
"@dnd-kit/sortable": "^8.0.0",
"@dnd-kit/utilities": "^3.2.2",
"@fontsource-variable/inter": "^5.0.20",
"@invoke-ai/ui-library": "^0.0.29",
"@nanostores/react": "^0.7.3",
"@fontsource-variable/inter": "^5.0.18",
"@invoke-ai/ui-library": "^0.0.25",
"@nanostores/react": "^0.7.2",
"@reduxjs/toolkit": "2.2.3",
"@roarr/browser-log-writer": "^1.3.0",
"chakra-react-select": "^4.9.1",
"compare-versions": "^6.1.1",
"@xyflow/react": "^12.0.4",
"chakra-react-select": "^4.7.6",
"compare-versions": "^6.1.0",
"dateformat": "^5.0.3",
"fracturedjsonjs": "^4.0.2",
"framer-motion": "^11.3.24",
"i18next": "^23.12.2",
"i18next-http-backend": "^2.5.2",
"fracturedjsonjs": "^4.0.1",
"framer-motion": "^11.1.8",
"i18next": "^23.11.3",
"i18next-http-backend": "^2.5.1",
"idb-keyval": "^6.2.1",
"jsondiffpatch": "^0.6.0",
"konva": "^9.3.14",
"konva": "^9.3.6",
"lodash-es": "^4.17.21",
"nanostores": "^0.11.2",
"nanostores": "^0.10.3",
"new-github-issue-url": "^1.0.0",
"overlayscrollbars": "^2.10.0",
"overlayscrollbars": "^2.7.3",
"overlayscrollbars-react": "^0.5.6",
"query-string": "^9.1.0",
"query-string": "^9.0.0",
"react": "^18.3.1",
"react-colorful": "^5.6.1",
"react-dom": "^18.3.1",
"react-dropzone": "^14.2.3",
"react-error-boundary": "^4.0.13",
"react-hook-form": "^7.52.2",
"react-hook-form": "^7.51.4",
"react-hotkeys-hook": "4.5.0",
"react-i18next": "^14.1.3",
"react-icons": "^5.2.1",
"react-i18next": "^14.1.1",
"react-icons": "^5.2.0",
"react-konva": "^18.2.10",
"react-redux": "9.1.2",
"react-resizable-panels": "^2.0.23",
"react-resizable-panels": "^2.0.19",
"react-select": "5.8.0",
"react-use": "^17.5.1",
"react-virtuoso": "^4.9.0",
"reactflow": "^11.11.4",
"react-use": "^17.5.0",
"react-virtuoso": "^4.7.10",
"redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^5.1.0",
"redux-undo": "^1.1.0",
"rfdc": "^1.4.1",
"rfdc": "^1.3.1",
"roarr": "^7.21.1",
"serialize-error": "^11.0.3",
"socket.io-client": "^4.7.5",
"use-debounce": "^10.0.2",
"use-debounce": "^10.0.0",
"use-device-pixel-ratio": "^1.1.2",
"use-image": "^1.1.1",
"uuid": "^10.0.0",
"zod": "^3.23.8",
"zod-validation-error": "^3.3.1"
"uuid": "^9.0.1",
"zod": "^3.23.6",
"zod-validation-error": "^3.2.0"
},
"peerDependencies": {
"@chakra-ui/react": "^2.8.2",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"ts-toolbelt": "^9.6.0"
@ -117,38 +118,38 @@
"devDependencies": {
"@invoke-ai/eslint-config-react": "^0.0.14",
"@invoke-ai/prettier-config-react": "^0.0.7",
"@storybook/addon-essentials": "^8.2.8",
"@storybook/addon-interactions": "^8.2.8",
"@storybook/addon-links": "^8.2.8",
"@storybook/addon-storysource": "^8.2.8",
"@storybook/manager-api": "^8.2.8",
"@storybook/react": "^8.2.8",
"@storybook/react-vite": "^8.2.8",
"@storybook/theming": "^8.2.8",
"@storybook/addon-essentials": "^8.0.10",
"@storybook/addon-interactions": "^8.0.10",
"@storybook/addon-links": "^8.0.10",
"@storybook/addon-storysource": "^8.0.10",
"@storybook/manager-api": "^8.0.10",
"@storybook/react": "^8.0.10",
"@storybook/react-vite": "^8.0.10",
"@storybook/theming": "^8.0.10",
"@types/dateformat": "^5.0.2",
"@types/lodash-es": "^4.17.12",
"@types/node": "^20.14.15",
"@types/react": "^18.3.3",
"@types/node": "^20.12.10",
"@types/react": "^18.3.1",
"@types/react-dom": "^18.3.0",
"@types/uuid": "^10.0.0",
"@vitejs/plugin-react-swc": "^3.7.0",
"@types/uuid": "^9.0.8",
"@vitejs/plugin-react-swc": "^3.6.0",
"@vitest/coverage-v8": "^1.5.0",
"@vitest/ui": "^1.5.0",
"concurrently": "^8.2.2",
"dpdm": "^3.14.0",
"eslint": "^8.57.0",
"eslint-plugin-i18next": "^6.0.9",
"eslint-plugin-i18next": "^6.0.3",
"eslint-plugin-path": "^1.3.0",
"knip": "^5.27.2",
"knip": "^5.12.3",
"openapi-types": "^12.1.3",
"openapi-typescript": "^7.3.0",
"prettier": "^3.3.3",
"openapi-typescript": "^6.7.5",
"prettier": "^3.2.5",
"rollup-plugin-visualizer": "^5.12.0",
"storybook": "^8.2.8",
"storybook": "^8.0.10",
"ts-toolbelt": "^9.6.0",
"tsafe": "^1.7.2",
"typescript": "^5.5.4",
"vite": "^5.4.0",
"tsafe": "^1.6.6",
"typescript": "^5.4.5",
"vite": "^5.2.11",
"vite-plugin-css-injected-by-js": "^3.5.1",
"vite-plugin-dts": "^3.9.1",
"vite-plugin-eslint": "^1.8.1",

File diff suppressed because it is too large Load Diff

View File

@ -696,8 +696,6 @@
"availableModels": "Available Models",
"baseModel": "Base Model",
"cancel": "Cancel",
"clipEmbed": "CLIP Embed",
"clipVision": "CLIP Vision",
"config": "Config",
"convert": "Convert",
"convertingModelBegin": "Converting Model. Please wait.",
@ -785,16 +783,13 @@
"settings": "Settings",
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
"source": "Source",
"spandrelImageToImage": "Image to Image (Spandrel)",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"syncModels": "Sync Models",
"textualInversions": "Textual Inversions",
"triggerPhrases": "Trigger Phrases",
"loraTriggerPhrases": "LoRA Trigger Phrases",
"mainModelTriggerPhrases": "Main Model Trigger Phrases",
"typePhraseHere": "Type phrase here",
"t5Encoder": "T5 Encoder",
"upcastAttention": "Upcast Attention",
"uploadImage": "Upload Image",
"urlOrLocalPath": "URL or Local Path",
@ -1146,8 +1141,6 @@
"imageSavingFailed": "Image Saving Failed",
"imageUploaded": "Image Uploaded",
"imageUploadFailed": "Image Upload Failed",
"importFailed": "Import Failed",
"importSuccessful": "Import Successful",
"invalidUpload": "Invalid Upload",
"loadedWithWarnings": "Workflow Loaded with Warnings",
"maskSavedAssets": "Mask Saved to Assets",
@ -1680,7 +1673,6 @@
"layers_other": "Layers"
},
"upscaling": {
"upscale": "Upscale",
"creativity": "Creativity",
"exceedsMaxSize": "Upscale settings exceed max size limit",
"exceedsMaxSizeDetails": "Max upscale limit is {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixels. Please try a smaller image or decrease your scale selection.",
@ -1697,53 +1689,6 @@
"missingUpscaleModel": "Missing upscale model",
"missingTileControlNetModel": "No valid tile ControlNet models installed"
},
"stylePresets": {
"active": "Active",
"choosePromptTemplate": "Choose Prompt Template",
"clearTemplateSelection": "Clear Template Selection",
"copyTemplate": "Copy Template",
"createPromptTemplate": "Create Prompt Template",
"defaultTemplates": "Default Templates",
"deleteImage": "Delete Image",
"deleteTemplate": "Delete Template",
"deleteTemplate2": "Are you sure you want to delete this template? This cannot be undone.",
"exportPromptTemplates": "Export My Prompt Templates (CSV)",
"editTemplate": "Edit Template",
"exportDownloaded": "Export Downloaded",
"exportFailed": "Unable to generate and download CSV",
"flatten": "Flatten selected template into current prompt",
"importTemplates": "Import Prompt Templates (CSV/JSON)",
"acceptedColumnsKeys": "Accepted columns/keys:",
"nameColumn": "'name'",
"positivePromptColumn": "'prompt' or 'positive_prompt'",
"negativePromptColumn": "'negative_prompt'",
"insertPlaceholder": "Insert placeholder",
"myTemplates": "My Templates",
"name": "Name",
"negativePrompt": "Negative Prompt",
"noTemplates": "No templates",
"noMatchingTemplates": "No matching templates",
"promptTemplatesDesc1": "Prompt templates add text to the prompts you write in the prompt box.",
"promptTemplatesDesc2": "Use the placeholder string <Pre>{{placeholder}}</Pre> to specify where your prompt should be included in the template.",
"promptTemplatesDesc3": "If you omit the placeholder, the template will be appended to the end of your prompt.",
"positivePrompt": "Positive Prompt",
"preview": "Preview",
"private": "Private",
"promptTemplateCleared": "Prompt Template Cleared",
"searchByName": "Search by name",
"shared": "Shared",
"sharedTemplates": "Shared Templates",
"templateActions": "Template Actions",
"templateDeleted": "Prompt template deleted",
"toggleViewMode": "Toggle View Mode",
"type": "Type",
"unableToDeleteTemplate": "Unable to delete prompt template",
"updatePromptTemplate": "Update Prompt Template",
"uploadImage": "Upload Image",
"useForTemplate": "Use For Prompt Template",
"viewList": "View Template List",
"viewModeTooltip": "This is how your prompt will look with your currently selected template. To edit your prompt, click anywhere in the text box."
},
"upsell": {
"inviteTeammates": "Invite Teammates",
"professional": "Professional",

View File

@ -90,7 +90,7 @@
"disabled": "Disabilitato",
"comparingDesc": "Confronta due immagini",
"comparing": "Confronta",
"dontShowMeThese": "Non mostrare più"
"dontShowMeThese": "Non mostrarmi questi"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@ -701,9 +701,7 @@
"baseModelChanged": "Modello base modificato",
"sessionRef": "Sessione: {{sessionId}}",
"somethingWentWrong": "Qualcosa è andato storto",
"outOfMemoryErrorDesc": "Le impostazioni della generazione attuale superano la capacità del sistema. Modifica le impostazioni e riprova.",
"importFailed": "Importazione non riuscita",
"importSuccessful": "Importazione riuscita"
"outOfMemoryErrorDesc": "Le impostazioni della generazione attuale superano la capacità del sistema. Modifica le impostazioni e riprova."
},
"tooltip": {
"feature": {
@ -929,7 +927,7 @@
"missingInvocationTemplate": "Modello di invocazione mancante",
"missingFieldTemplate": "Modello di campo mancante",
"singleFieldType": "{{name}} (Singola)",
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino ai valori predefiniti",
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino delle impostazioni predefinite",
"boardAccessError": "Impossibile trovare la bacheca {{board_id}}, ripristino ai valori predefiniti",
"modelAccessError": "Impossibile trovare il modello {{key}}, ripristino ai valori predefiniti"
},
@ -1528,7 +1526,7 @@
},
"upscaleModel": {
"paragraphs": [
"Il modello di ampliamento (Upscale), scala l'immagine alle dimensioni di uscita prima di aggiungere i dettagli. È possibile utilizzare qualsiasi modello di ampliamento supportato, ma alcuni sono specializzati per diversi tipi di immagini, come foto o disegni al tratto."
"Il modello di ampliamento ridimensiona l'immagine alle dimensioni di uscita prima che vengano aggiunti i dettagli. È possibile utilizzare qualsiasi modello di ampliamento supportato, ma alcuni sono specializzati per diversi tipi di immagini, come foto o disegni al tratto."
],
"heading": "Modello di ampliamento"
},
@ -1737,58 +1735,12 @@
"missingUpscaleModel": "Modello per lampliamento mancante",
"missingTileControlNetModel": "Nessun modello ControlNet Tile valido installato",
"postProcessingModel": "Modello di post-elaborazione",
"postProcessingMissingModelWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare un modello di post-elaborazione (da immagine a immagine).",
"exceedsMaxSize": "Le impostazioni di ampliamento superano il limite massimo delle dimensioni",
"exceedsMaxSizeDetails": "Il limite massimo di ampliamento è {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixel. Prova un'immagine più piccola o diminuisci la scala selezionata."
"postProcessingMissingModelWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare un modello di post-elaborazione (da immagine a immagine)."
},
"upsell": {
"inviteTeammates": "Invita collaboratori",
"shareAccess": "Condividi l'accesso",
"professional": "Professionale",
"professionalUpsell": "Disponibile nell'edizione Professional di Invoke. Fai clic qui o visita invoke.com/pricing per ulteriori dettagli."
},
"stylePresets": {
"active": "Attivo",
"choosePromptTemplate": "Scegli un modello di prompt",
"clearTemplateSelection": "Cancella selezione modello",
"copyTemplate": "Copia modello",
"createPromptTemplate": "Crea modello di prompt",
"defaultTemplates": "Modelli predefiniti",
"deleteImage": "Elimina immagine",
"deleteTemplate": "Elimina modello",
"editTemplate": "Modifica modello",
"flatten": "Unisci il modello selezionato al prompt corrente",
"insertPlaceholder": "Inserisci segnaposto",
"myTemplates": "I miei modelli",
"name": "Nome",
"negativePrompt": "Prompt Negativo",
"noMatchingTemplates": "Nessun modello corrispondente",
"promptTemplatesDesc1": "I modelli di prompt aggiungono testo ai prompt che scrivi nelle caselle dei prompt.",
"promptTemplatesDesc3": "Se si omette il segnaposto, il modello verrà aggiunto alla fine del prompt.",
"positivePrompt": "Prompt Positivo",
"preview": "Anteprima",
"private": "Privato",
"searchByName": "Cerca per nome",
"shared": "Condiviso",
"sharedTemplates": "Modelli condivisi",
"templateDeleted": "Modello di prompt eliminato",
"toggleViewMode": "Attiva/disattiva visualizzazione",
"uploadImage": "Carica immagine",
"useForTemplate": "Usa per modello di prompt",
"viewList": "Visualizza l'elenco dei modelli",
"viewModeTooltip": "Ecco come apparirà il tuo prompt con il modello attualmente selezionato. Per modificare il tuo prompt, clicca in un punto qualsiasi della casella di testo.",
"deleteTemplate2": "Vuoi davvero eliminare questo modello? Questa operazione non può essere annullata.",
"unableToDeleteTemplate": "Impossibile eliminare il modello di prompt",
"updatePromptTemplate": "Aggiorna il modello di prompt",
"type": "Tipo",
"promptTemplatesDesc2": "Utilizza la stringa segnaposto <Pre>{{placeholder}}</Pre> per specificare dove inserire il tuo prompt nel modello.",
"importTemplates": "Importa modelli di prompt (CSV/JSON)",
"exportDownloaded": "Esportazione completata",
"exportFailed": "Impossibile generare e scaricare il file CSV",
"exportPromptTemplates": "Esporta i miei modelli di prompt (CSV)",
"positivePromptColumn": "'prompt' o 'positive_prompt'",
"noTemplates": "Nessun modello",
"acceptedColumnsKeys": "Colonne/chiavi accettate:",
"templateActions": "Azioni modello"
}
}

View File

@ -91,8 +91,7 @@
"enabled": "Включено",
"disabled": "Отключено",
"comparingDesc": "Сравнение двух изображений",
"comparing": "Сравнение",
"dontShowMeThese": "Не показывай мне это"
"comparing": "Сравнение"
},
"gallery": {
"galleryImageSize": "Размер изображений",
@ -154,11 +153,7 @@
"showArchivedBoards": "Показать архивированные доски",
"searchImages": "Поиск по метаданным",
"displayBoardSearch": "Отобразить поиск досок",
"displaySearch": "Отобразить поиск",
"exitBoardSearch": "Выйти из поиска досок",
"go": "Перейти",
"exitSearch": "Выйти из поиска",
"jump": "Пыгнуть"
"displaySearch": "Отобразить поиск"
},
"hotkeys": {
"keyboardShortcuts": "Горячие клавиши",
@ -381,10 +376,6 @@
"toggleViewer": {
"title": "Переключить просмотр изображений",
"desc": "Переключение между средством просмотра изображений и рабочей областью для текущей вкладки."
},
"postProcess": {
"desc": "Обработайте текущее изображение с помощью выбранной модели постобработки",
"title": "Обработать изображение"
}
},
"modelManager": {
@ -598,10 +589,7 @@
"infillColorValue": "Цвет заливки",
"globalSettings": "Глобальные настройки",
"globalNegativePromptPlaceholder": "Глобальный негативный запрос",
"globalPositivePromptPlaceholder": "Глобальный запрос",
"postProcessing": "Постобработка (Shift + U)",
"processImage": "Обработка изображения",
"sendToUpscale": "Отправить на увеличение"
"globalPositivePromptPlaceholder": "Глобальный запрос"
},
"settings": {
"models": "Модели",
@ -635,9 +623,7 @@
"intermediatesCleared_many": "Очищено {{count}} промежуточных",
"clearIntermediatesDesc1": "Очистка промежуточных элементов приведет к сбросу состояния Canvas и ControlNet.",
"intermediatesClearedFailed": "Проблема очистки промежуточных",
"reloadingIn": "Перезагрузка через",
"informationalPopoversDisabled": "Информационные всплывающие окна отключены",
"informationalPopoversDisabledDesc": "Информационные всплывающие окна были отключены. Включите их в Настройках."
"reloadingIn": "Перезагрузка через"
},
"toast": {
"uploadFailed": "Загрузка не удалась",
@ -708,9 +694,7 @@
"sessionRef": "Сессия: {{sessionId}}",
"outOfMemoryError": "Ошибка нехватки памяти",
"outOfMemoryErrorDesc": "Ваши текущие настройки генерации превышают возможности системы. Пожалуйста, измените настройки и повторите попытку.",
"somethingWentWrong": "Что-то пошло не так",
"importFailed": "Импорт неудачен",
"importSuccessful": "Импорт успешен"
"somethingWentWrong": "Что-то пошло не так"
},
"tooltip": {
"feature": {
@ -1033,8 +1017,7 @@
"composition": "Только композиция",
"hed": "HED",
"beginEndStepPercentShort": "Начало/конец %",
"setControlImageDimensionsForce": "Скопируйте размер в Ш/В (игнорируйте модель)",
"depthAnythingSmallV2": "Small V2"
"setControlImageDimensionsForce": "Скопируйте размер в Ш/В (игнорируйте модель)"
},
"boards": {
"autoAddBoard": "Авто добавление Доски",
@ -1059,7 +1042,7 @@
"downloadBoard": "Скачать доску",
"deleteBoard": "Удалить доску",
"deleteBoardAndImages": "Удалить доску и изображения",
"deletedBoardsCannotbeRestored": "Удаленные доски не могут быть восстановлены. Выбор «Удалить только доску» переведет изображения в состояние без категории.",
"deletedBoardsCannotbeRestored": "Удаленные доски не подлежат восстановлению",
"assetsWithCount_one": "{{count}} ассет",
"assetsWithCount_few": "{{count}} ассета",
"assetsWithCount_many": "{{count}} ассетов",
@ -1074,11 +1057,7 @@
"boards": "Доски",
"addPrivateBoard": "Добавить личную доску",
"private": "Личные доски",
"shared": "Общие доски",
"hideBoards": "Скрыть доски",
"viewBoards": "Просмотреть доски",
"noBoards": "Нет досок {{boardType}}",
"deletedPrivateBoardsCannotbeRestored": "Удаленные доски не могут быть восстановлены. Выбор «Удалить только доску» переведет изображения в приватное состояние без категории для создателя изображения."
"shared": "Общие доски"
},
"dynamicPrompts": {
"seedBehaviour": {
@ -1438,30 +1417,6 @@
"paragraphs": [
"Метод, с помощью которого применяется текущий IP-адаптер."
]
},
"structure": {
"paragraphs": [
"Структура контролирует, насколько точно выходное изображение будет соответствовать макету оригинала. Низкая структура допускает значительные изменения, в то время как высокая структура строго сохраняет исходную композицию и макет."
],
"heading": "Структура"
},
"scale": {
"paragraphs": [
"Масштаб управляет размером выходного изображения и основывается на кратном разрешении входного изображения. Например, при увеличении в 2 раза изображения 1024x1024 на выходе получится 2048 x 2048."
],
"heading": "Масштаб"
},
"creativity": {
"paragraphs": [
"Креативность контролирует степень свободы, предоставляемой модели при добавлении деталей. При низкой креативности модель остается близкой к оригинальному изображению, в то время как высокая креативность позволяет вносить больше изменений. При использовании подсказки высокая креативность увеличивает влияние подсказки."
],
"heading": "Креативность"
},
"upscaleModel": {
"heading": "Модель увеличения",
"paragraphs": [
"Модель увеличения масштаба масштабирует изображение до выходного размера перед добавлением деталей. Можно использовать любую поддерживаемую модель масштабирования, но некоторые из них специализированы для различных видов изображений, например фотографий или линейных рисунков."
]
}
},
"metadata": {
@ -1738,78 +1693,7 @@
"canvasTab": "$t(ui.tabs.canvas) $t(common.tab)",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
"queue": "Очередь",
"upscaling": "Увеличение",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
"queue": "Очередь"
}
},
"upscaling": {
"exceedsMaxSize": "Параметры масштабирования превышают максимальный размер",
"exceedsMaxSizeDetails": "Максимальный предел масштабирования составляет {{maxUpscaleDimension}}x{{maxUpscaleDimension}} пикселей. Пожалуйста, попробуйте использовать меньшее изображение или уменьшите масштаб.",
"structure": "Структура",
"missingTileControlNetModel": "Не установлены подходящие модели ControlNet",
"missingUpscaleInitialImage": "Отсутствует увеличиваемое изображение",
"missingUpscaleModel": "Отсутствует увеличивающая модель",
"creativity": "Креативность",
"upscaleModel": "Модель увеличения",
"scale": "Масштаб",
"mainModelDesc": "Основная модель (архитектура SD1.5 или SDXL)",
"upscaleModelDesc": "Модель увеличения (img2img)",
"postProcessingModel": "Модель постобработки",
"tileControlNetModelDesc": "Модель ControlNet для выбранной архитектуры основной модели",
"missingModelsWarning": "Зайдите в <LinkComponent>Менеджер моделей</LinkComponent> чтоб установить необходимые модели:",
"postProcessingMissingModelWarning": "Посетите <LinkComponent>Менеджер моделей</LinkComponent>, чтобы установить модель постобработки (img2img)."
},
"stylePresets": {
"noMatchingTemplates": "Нет подходящих шаблонов",
"promptTemplatesDesc1": "Шаблоны подсказок добавляют текст к подсказкам, которые вы пишете в окне подсказок.",
"sharedTemplates": "Общие шаблоны",
"templateDeleted": "Шаблон запроса удален",
"toggleViewMode": "Переключить режим просмотра",
"type": "Тип",
"unableToDeleteTemplate": "Не получилось удалить шаблон запроса",
"viewModeTooltip": "Вот как будет выглядеть ваш запрос с выбранным шаблоном. Чтобы его отредактировать, щелкните в любом месте текстового поля.",
"viewList": "Просмотреть список шаблонов",
"active": "Активно",
"choosePromptTemplate": "Выберите шаблон запроса",
"defaultTemplates": "Стандартные шаблоны",
"deleteImage": "Удалить изображение",
"deleteTemplate": "Удалить шаблон",
"deleteTemplate2": "Вы уверены, что хотите удалить этот шаблон? Это нельзя отменить.",
"editTemplate": "Редактировать шаблон",
"exportPromptTemplates": "Экспорт моих шаблонов запроса (CSV)",
"exportDownloaded": "Экспорт скачан",
"exportFailed": "Невозможно сгенерировать и загрузить CSV",
"flatten": "Объединить выбранный шаблон с текущим запросом",
"acceptedColumnsKeys": "Принимаемые столбцы/ключи:",
"positivePromptColumn": "'prompt' или 'positive_prompt'",
"insertPlaceholder": "Вставить заполнитель",
"name": "Имя",
"negativePrompt": "Негативный запрос",
"promptTemplatesDesc3": "Если вы не используете заполнитель, шаблон будет добавлен в конец запроса.",
"positivePrompt": "Позитивный запрос",
"preview": "Предпросмотр",
"private": "Приватный",
"templateActions": "Действия с шаблоном",
"updatePromptTemplate": "Обновить шаблон запроса",
"uploadImage": "Загрузить изображение",
"useForTemplate": "Использовать для шаблона запроса",
"clearTemplateSelection": "Очистить выбор шаблона",
"copyTemplate": "Копировать шаблон",
"createPromptTemplate": "Создать шаблон запроса",
"importTemplates": "Импортировать шаблоны запроса (CSV/JSON)",
"nameColumn": "'name'",
"negativePromptColumn": "'negative_prompt'",
"myTemplates": "Мои шаблоны",
"noTemplates": "Нет шаблонов",
"promptTemplatesDesc2": "Используйте строку-заполнитель <Pre>{{placeholder}}</Pre>, чтобы указать место, куда должен быть включен ваш запрос в шаблоне.",
"searchByName": "Поиск по имени",
"shared": "Общий"
},
"upsell": {
"inviteTeammates": "Пригласите членов команды",
"professional": "Профессионал",
"professionalUpsell": "Доступно в профессиональной версии Invoke. Нажмите здесь или посетите invoke.com/pricing для получения более подробной информации.",
"shareAccess": "Поделиться доступом"
}
}

View File

@ -493,8 +493,7 @@
"defaultSettingsSaved": "默认设置已保存",
"huggingFacePlaceholder": "所有者或模型名称",
"huggingFaceRepoID": "HuggingFace仓库ID",
"loraTriggerPhrases": "LoRA 触发词",
"ipAdapters": "IP适配器"
"loraTriggerPhrases": "LoRA 触发词"
},
"parameters": {
"images": "图像",
@ -1703,9 +1702,7 @@
"upscaleModelDesc": "图像放大(图像到图像转换)模型",
"postProcessingMissingModelWarning": "请访问 <LinkComponent>模型管理器</LinkComponent>来安装一个后处理(图像到图像转换)模型.",
"missingModelsWarning": "请访问<LinkComponent>模型管理器</LinkComponent> 安装所需的模型:",
"mainModelDesc": "主模型SD1.5或SDXL架构",
"exceedsMaxSize": "放大设置超出了最大尺寸限制",
"exceedsMaxSizeDetails": "最大放大限制是 {{maxUpscaleDimension}}x{{maxUpscaleDimension}} 像素.请尝试一个较小的图像或减少您的缩放选择."
"mainModelDesc": "主模型SD1.5或SDXL架构"
},
"upsell": {
"inviteTeammates": "邀请团队成员",

View File

@ -1,40 +1,26 @@
/* eslint-disable no-console */
import fs from 'node:fs';
import openapiTS, { astToString } from 'openapi-typescript';
import ts from 'typescript';
import openapiTS from 'openapi-typescript';
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
const OUTPUT_FILE = 'src/services/api/schema.ts';
async function generateTypes(schema) {
process.stdout.write(`Generating types ${OUTPUT_FILE}...`);
// Use https://ts-ast-viewer.com to figure out how to create these AST nodes - define a type and use the bottom-left pane's output
// `Blob` type
const BLOB = ts.factory.createTypeReferenceNode(ts.factory.createIdentifier('Blob'));
// `null` type
const NULL = ts.factory.createLiteralTypeNode(ts.factory.createNull());
// `Record<string, unknown>` type
const RECORD_STRING_UNKNOWN = ts.factory.createTypeReferenceNode(ts.factory.createIdentifier('Record'), [
ts.factory.createKeywordTypeNode(ts.SyntaxKind.StringKeyword),
ts.factory.createKeywordTypeNode(ts.SyntaxKind.UnknownKeyword),
]);
const types = await openapiTS(schema, {
exportType: true,
transform: (schemaObject) => {
if ('format' in schemaObject && schemaObject.format === 'binary') {
return schemaObject.nullable ? ts.factory.createUnionTypeNode([BLOB, NULL]) : BLOB;
return schemaObject.nullable ? 'Blob | null' : 'Blob';
}
if (schemaObject.title === 'MetadataField') {
// This is `Record<string, never>` by default, but it actually accepts any a dict of any valid JSON value.
return RECORD_STRING_UNKNOWN;
return 'Record<string, unknown>';
}
},
defaultNonNullable: false,
});
fs.writeFileSync(OUTPUT_FILE, astToString(types));
fs.writeFileSync(OUTPUT_FILE, types);
process.stdout.write(`\nOK!\r\n`);
}

View File

@ -13,14 +13,11 @@ import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardMo
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors';
import InvokeTabs from 'features/ui/components/InvokeTabs';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useGetAndLoadLibraryWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow';
import { AnimatePresence } from 'framer-motion';
import i18n from 'i18n';
import { size } from 'lodash-es';
@ -39,18 +36,10 @@ interface Props {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
selectedWorkflowId?: string;
selectedStylePresetId?: string;
destination?: InvokeTabName | undefined;
}
const App = ({
config = DEFAULT_CONFIG,
selectedImage,
selectedWorkflowId,
selectedStylePresetId,
destination,
}: Props) => {
const App = ({ config = DEFAULT_CONFIG, selectedImage, destination }: Props) => {
const language = useAppSelector(languageSelector);
const logger = useLogger('system');
const dispatch = useAppDispatch();
@ -81,20 +70,6 @@ const App = ({
}
}, [dispatch, config, logger]);
const { getAndLoadWorkflow } = useGetAndLoadLibraryWorkflow();
useEffect(() => {
if (selectedWorkflowId) {
getAndLoadWorkflow(selectedWorkflowId);
}
}, [selectedWorkflowId, getAndLoadWorkflow]);
useEffect(() => {
if (selectedStylePresetId) {
dispatch(activeStylePresetIdChanged(selectedStylePresetId));
}
}, [dispatch, selectedStylePresetId]);
useEffect(() => {
if (destination) {
dispatch(setActiveTab(destination));
@ -129,7 +104,6 @@ const App = ({
<DeleteImageModal />
<ChangeBoardModal />
<DynamicPromptsModal />
<StylePresetModal />
<PreselectedImage selectedImage={selectedImage} />
</ErrorBoundary>
);

View File

@ -44,8 +44,6 @@ interface Props extends PropsWithChildren {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
selectedWorkflowId?: string;
selectedStylePresetId?: string;
destination?: InvokeTabName;
customStarUi?: CustomStarUi;
socketOptions?: Partial<ManagerOptions & SocketOptions>;
@ -66,8 +64,6 @@ const InvokeAIUI = ({
projectUrl,
queueId,
selectedImage,
selectedWorkflowId,
selectedStylePresetId,
destination,
customStarUi,
socketOptions,
@ -225,13 +221,7 @@ const InvokeAIUI = ({
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<AppDndContext>
<App
config={config}
selectedImage={selectedImage}
selectedWorkflowId={selectedWorkflowId}
selectedStylePresetId={selectedStylePresetId}
destination={destination}
/>
<App config={config} selectedImage={selectedImage} destination={destination} />
</AppDndContext>
</ThemeLocaleProvider>
</React.Suspense>

View File

@ -11,9 +11,6 @@ import {
promptsChanged,
} from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
import { utilitiesApi } from 'services/api/endpoints/utilities';
import { socketConnected } from 'services/events/actions';
@ -22,11 +19,7 @@ const matcher = isAnyOf(
combinatorialToggled,
maxPromptsChanged,
maxPromptsReset,
socketConnected,
activeStylePresetIdChanged,
stylePresetsApi.endpoints.deleteStylePreset.matchFulfilled,
stylePresetsApi.endpoints.updateStylePreset.matchFulfilled,
stylePresetsApi.endpoints.listStylePresets.matchFulfilled
socketConnected
);
export const addDynamicPromptsListener = (startAppListening: AppStartListening) => {
@ -35,7 +28,7 @@ export const addDynamicPromptsListener = (startAppListening: AppStartListening)
effect: async (action, { dispatch, getState, cancelActiveListeners, delay }) => {
cancelActiveListeners();
const state = getState();
const { positivePrompt } = getPresetModifiedPrompts(state);
const { positivePrompt } = state.controlLayers.present;
const { maxPrompts } = state.dynamicPrompts;
if (state.config.disabledFeatures.includes('dynamicPrompting')) {

View File

@ -28,7 +28,6 @@ import { generationPersistConfig, generationSlice } from 'features/parameters/st
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
import { queueSlice } from 'features/queue/store/queueSlice';
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
import { stylePresetPersistConfig, stylePresetSlice } from 'features/stylePresets/store/stylePresetSlice';
import { configSlice } from 'features/system/store/configSlice';
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
import { uiPersistConfig, uiSlice } from 'features/ui/store/uiSlice';
@ -70,7 +69,6 @@ const allReducers = {
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
[api.reducerPath]: api.reducer,
[upscaleSlice.name]: upscaleSlice.reducer,
[stylePresetSlice.name]: stylePresetSlice.reducer,
};
const rootReducer = combineReducers(allReducers);
@ -116,7 +114,6 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[controlLayersPersistConfig.name]: controlLayersPersistConfig,
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
[upscalePersistConfig.name]: upscalePersistConfig,
[stylePresetPersistConfig.name]: stylePresetPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {
@ -167,8 +164,8 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
serializableCheck: import.meta.env.MODE === 'development',
immutableCheck: import.meta.env.MODE === 'development',
serializableCheck: false,
immutableCheck: false,
})
.concat(api.middleware)
.concat(dynamicMiddlewares)

View File

@ -71,7 +71,6 @@ export type AppConfig = {
*/
maxUpscaleDimension?: number;
allowPrivateBoards: boolean;
allowPrivateStylePresets: boolean;
disabledTabs: InvokeTabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];

View File

@ -47,7 +47,6 @@ export const IAINoContentFallback = memo((props: IAINoImageFallbackProps) => {
userSelect: 'none',
opacity: 0.7,
color: 'base.500',
fontSize: 'md',
...sx,
}),
[sx]
@ -56,7 +55,11 @@ export const IAINoContentFallback = memo((props: IAINoImageFallbackProps) => {
return (
<Flex sx={styles} {...rest}>
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
{props.label && <Text textAlign="center">{props.label}</Text>}
{props.label && (
<Text textAlign="center" fontSize="md">
{props.label}
</Text>
)}
</Flex>
);
});

View File

@ -1,4 +1,4 @@
import { convertImageUrlToBlob } from 'common/util/convertImageUrlToBlob';
import { useImageUrlToBlob } from 'common/hooks/useImageUrlToBlob';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
@ -6,6 +6,7 @@ import { useTranslation } from 'react-i18next';
export const useCopyImageToClipboard = () => {
const { t } = useTranslation();
const imageUrlToBlob = useImageUrlToBlob();
const isClipboardAPIAvailable = useMemo(() => {
return Boolean(navigator.clipboard) && Boolean(window.ClipboardItem);
@ -22,7 +23,7 @@ export const useCopyImageToClipboard = () => {
});
}
try {
const blob = await convertImageUrlToBlob(image_url);
const blob = await imageUrlToBlob(image_url);
if (!blob) {
throw new Error('Unable to create Blob');
@ -44,7 +45,7 @@ export const useCopyImageToClipboard = () => {
});
}
},
[isClipboardAPIAvailable, t]
[imageUrlToBlob, isClipboardAPIAvailable, t]
);
return { isClipboardAPIAvailable, copyImageToClipboard };

View File

@ -0,0 +1,40 @@
import { $authToken } from 'app/store/nanostores/authToken';
import { useCallback } from 'react';
/**
* Converts an image URL to a Blob by creating an <img /> element, drawing it to canvas
* and then converting the canvas to a Blob.
*
* @returns A function that takes a URL and returns a Promise that resolves with a Blob
*/
export const useImageUrlToBlob = () => {
const imageUrlToBlob = useCallback(
async (url: string) =>
new Promise<Blob | null>((resolve) => {
const img = new Image();
img.onload = () => {
const canvas = document.createElement('canvas');
canvas.width = img.width;
canvas.height = img.height;
const context = canvas.getContext('2d');
if (!context) {
return;
}
context.drawImage(img, 0, 0);
resolve(
new Promise<Blob | null>((resolve) => {
canvas.toBlob(function (blob) {
resolve(blob);
}, 'image/png');
})
);
};
img.crossOrigin = $authToken.get() ? 'use-credentials' : 'anonymous';
img.src = url;
}),
[]
);
return imageUrlToBlob;
};

View File

@ -1,4 +1,5 @@
import { useStore } from '@nanostores/react';
import { getConnectedEdges } from '@xyflow/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
@ -22,7 +23,6 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import i18n from 'i18next';
import { forEach, upperFirst } from 'lodash-es';
import { useMemo } from 'react';
import { getConnectedEdges } from 'reactflow';
const LAYER_TYPE_TO_TKEY: Record<Layer['type'], string> = {
initial_image_layer: 'controlLayers.globalInitialImage',

View File

@ -1,39 +0,0 @@
import { $authToken } from 'app/store/nanostores/authToken';
/**
* Converts an image URL to a Blob by creating an <img /> element, drawing it to canvas
* and then converting the canvas to a Blob.
*
* @returns A function that takes a URL and returns a Promise that resolves with a Blob
*/
export const convertImageUrlToBlob = async (url: string) =>
new Promise<Blob | null>((resolve, reject) => {
const img = new Image();
img.onload = () => {
const canvas = document.createElement('canvas');
canvas.width = img.width;
canvas.height = img.height;
const context = canvas.getContext('2d');
if (!context) {
reject(new Error('Failed to get canvas context'));
return;
}
context.drawImage(img, 0, 0);
canvas.toBlob((blob) => {
if (blob) {
resolve(blob);
} else {
reject(new Error('Failed to convert image to blob'));
}
}, 'image/png');
};
img.onerror = () => {
reject(new Error('Image failed to load. The URL may be invalid or the object may not exist.'));
};
img.crossOrigin = $authToken.get() ? 'use-credentials' : 'anonymous';
img.src = url;
});

View File

@ -66,7 +66,7 @@ export const Gallery = () => {
<Flex flexDirection="column" alignItems="center" justifyContent="space-between" h="full" w="full" pt={1}>
<Tabs index={galleryView === 'images' ? 0 : 1} variant="enclosed" display="flex" flexDir="column" w="full">
<TabList gap={2} fontSize="sm" borderColor="base.800" alignItems="center" w="full">
<Text fontSize="sm" fontWeight="semibold" noOfLines={1} px="2" wordBreak="break-all">
<Text fontSize="sm" fontWeight="semibold" noOfLines={1} px="2">
{boardName}
</Text>
<Spacer />

View File

@ -30,7 +30,6 @@ import {
PiFlowArrowBold,
PiFoldersBold,
PiImagesBold,
PiPaintBrushBold,
PiPlantBold,
PiQuotesBold,
PiShareFatBold,
@ -56,17 +55,8 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { downloadImage } = useDownloadImage();
const templates = useStore($templates);
const {
recallAll,
remix,
recallSeed,
recallPrompts,
hasMetadata,
hasSeed,
hasPrompts,
isLoadingMetadata,
createAsPreset,
} = useImageActions(imageDTO?.image_name);
const { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, isLoadingMetadata } =
useImageActions(imageDTO?.image_name);
const { getAndLoadEmbeddedWorkflow, getAndLoadEmbeddedWorkflowResult } = useGetAndLoadEmbeddedWorkflow({});
@ -192,13 +182,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
icon={isLoadingMetadata ? <SpinnerIcon /> : <PiPaintBrushBold />}
onClickCapture={createAsPreset}
isDisabled={isLoadingMetadata || !hasPrompts}
>
{t('stylePresets.useForTemplate')}
</MenuItem>
<MenuDivider />
<MenuItem icon={<PiShareFatBold />} onClickCapture={handleSendToImageToImage} id="send-to-img2img">
{t('parameters.sendToImg2Img')}

View File

@ -60,6 +60,7 @@ const ImageGalleryContent = () => {
<GalleryHeader />
<Flex alignItems="center" justifyContent="space-between" w="full">
<Button
w={112}
size="sm"
variant="ghost"
onClick={handleToggleBoardPanel}

View File

@ -1,25 +1,15 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { handlers, parseAndRecallAllMetadata, parseAndRecallPrompts } from 'features/metadata/util/handlers';
import { $stylePresetModalState } from 'features/stylePresets/store/stylePresetModal';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
export const useImageActions = (image_name?: string) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const activeTabName = useAppSelector(activeTabNameSelector);
const activeStylePresetId = useAppSelector((s) => s.stylePreset.activeStylePresetId);
const { metadata, isLoading: isLoadingMetadata } = useDebouncedMetadata(image_name);
const [hasMetadata, setHasMetadata] = useState(false);
const [hasSeed, setHasSeed] = useState(false);
const [hasPrompts, setHasPrompts] = useState(false);
const { data: imageDTO } = useGetImageDTOQuery(image_name ?? skipToken);
useEffect(() => {
const parseMetadata = async () => {
@ -52,26 +42,14 @@ export const useImageActions = (image_name?: string) => {
parseMetadata();
}, [metadata]);
const clearStylePreset = useCallback(() => {
if (activeStylePresetId) {
dispatch(activeStylePresetIdChanged(null));
toast({
status: 'info',
title: t('stylePresets.promptTemplateCleared'),
});
}
}, [dispatch, activeStylePresetId, t]);
const recallAll = useCallback(() => {
parseAndRecallAllMetadata(metadata, activeTabName === 'generation');
clearStylePreset();
}, [activeTabName, metadata, clearStylePreset]);
}, [activeTabName, metadata]);
const remix = useCallback(() => {
// Recalls all metadata parameters except seed
parseAndRecallAllMetadata(metadata, activeTabName === 'generation', ['seed']);
clearStylePreset();
}, [activeTabName, metadata, clearStylePreset]);
}, [activeTabName, metadata]);
const recallSeed = useCallback(() => {
handlers.seed.parse(metadata).then((seed) => {
@ -81,48 +59,7 @@ export const useImageActions = (image_name?: string) => {
const recallPrompts = useCallback(() => {
parseAndRecallPrompts(metadata);
clearStylePreset();
}, [metadata, clearStylePreset]);
}, [metadata]);
const createAsPreset = useCallback(async () => {
if (image_name && metadata && imageDTO) {
let positivePrompt;
let negativePrompt;
try {
positivePrompt = await handlers.positivePrompt.parse(metadata);
} catch (error) {
positivePrompt = '';
}
try {
negativePrompt = await handlers.negativePrompt.parse(metadata);
} catch (error) {
negativePrompt = '';
}
$stylePresetModalState.set({
prefilledFormData: {
name: '',
positivePrompt,
negativePrompt,
imageUrl: imageDTO.image_url,
type: 'user',
},
updatingStylePresetId: null,
isModalOpen: true,
});
}
}, [image_name, metadata, imageDTO]);
return {
recallAll,
remix,
recallSeed,
recallPrompts,
hasMetadata,
hasSeed,
hasPrompts,
isLoadingMetadata,
createAsPreset,
};
return { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, isLoadingMetadata };
};

Some files were not shown because too many files have changed in this diff Show More