Merge branch 'main' into tiled-upscaling-graph

This commit is contained in:
skunkworxdark 2023-12-16 19:09:28 +00:00
commit b8354bd1a4
63 changed files with 2531 additions and 693 deletions

View File

@ -42,6 +42,21 @@ Please provide steps on how to test changes, any hardware or
software specifications as well as any other pertinent information.
-->
## Merge Plan
<!--
A merge plan describes how this PR should be handled after it is approved.
Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is merged"
A merge plan is particularly important for large PRs or PRs that touch the
database in any way.
-->
## Added/updated tests?
- [ ] Yes

View File

@ -21,16 +21,16 @@ jobs:
if: github.event.pull_request.draft == false
runs-on: ubuntu-22.04
steps:
- name: Setup Node 20
- name: Setup Node 18
uses: actions/setup-node@v4
with:
node-version: '20'
node-version: '18'
- name: Checkout
uses: actions/checkout@v4
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: 8
version: '8.12.1'
- name: Install dependencies
run: 'pnpm install --prefer-frozen-lockfile'
- name: Typescript

View File

@ -1,13 +1,15 @@
name: PyPI Release
on:
push:
paths:
- 'invokeai/version/invokeai_version.py'
workflow_dispatch:
inputs:
publish_package:
description: 'Publish build on PyPi? [true/false]'
required: true
default: 'false'
jobs:
release:
build-and-release:
if: github.repository == 'invoke-ai/InvokeAI'
runs-on: ubuntu-22.04
env:
@ -15,19 +17,43 @@ jobs:
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
TWINE_NON_INTERACTIVE: 1
steps:
- name: checkout sources
uses: actions/checkout@v3
- name: Checkout
uses: actions/checkout@v4
- name: install deps
- name: Setup Node 18
uses: actions/setup-node@v4
with:
node-version: '18'
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: '8.12.1'
- name: Install frontend dependencies
run: pnpm install --prefer-frozen-lockfile
working-directory: invokeai/frontend/web
- name: Build frontend
run: pnpm run build
working-directory: invokeai/frontend/web
- name: Install python dependencies
run: pip install --upgrade build twine
- name: build package
- name: Build python package
run: python3 -m build
- name: check distribution
- name: Upload build as workflow artifact
uses: actions/upload-artifact@v4
with:
name: dist
path: dist
- name: Check distribution
run: twine check dist/*
- name: check PyPI versions
- name: Check PyPI versions
if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
run: |
pip install --upgrade requests
@ -36,6 +62,6 @@ jobs:
EXISTS=scripts.pypi_helper.local_on_pypi(); \
print(f'PACKAGE_EXISTS={EXISTS}')" >> $GITHUB_ENV
- name: upload package
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != ''
- name: Publish build on PyPi
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != '' && github.event.inputs.publish_package == 'true'
run: twine upload dist/*

View File

@ -293,6 +293,19 @@ manager, please follow these steps:
## Developer Install
!!! warning
InvokeAI uses a SQLite database. By running on `main`, you accept responsibility for your database. This
means making regular backups (especially before pulling) and/or fixing it yourself in the event that a
PR introduces a schema change.
If you don't need persistent backend storage, you can use an ephemeral in-memory database by setting
`use_memory_db: true` under `Path:` in your `invokeai.yaml` file.
If this is untenable, you should run the application via the official installer or a manual install of the
python package from pypi. These releases will not break your database.
If you have an interest in how InvokeAI works, or you would like to
add features or bugfixes, you are encouraged to install the source
code for InvokeAI. For this to work, you will need to install the
@ -388,3 +401,5 @@ environment variable INVOKEAI_ROOT to point to the installation directory.
Note that if you run into problems with the Conda installation, the InvokeAI
staff will **not** be able to help you out. Caveat Emptor!
[dev-chat]: https://discord.com/channels/1020123559063990373/1049495067846524939

View File

@ -0,0 +1,10 @@
document.addEventListener("DOMContentLoaded", function () {
var script = document.createElement("script");
script.src = "https://widget.kapa.ai/kapa-widget.bundle.js";
script.setAttribute("data-website-id", "b5973bb1-476b-451e-8cf4-98de86745a10");
script.setAttribute("data-project-name", "Invoke.AI");
script.setAttribute("data-project-color", "#11213C");
script.setAttribute("data-project-logo", "https://avatars.githubusercontent.com/u/113954515?s=280&v=4");
script.async = true;
document.head.appendChild(script);
});

View File

@ -91,9 +91,11 @@ rm -rf InvokeAI-Installer
# copy content
mkdir InvokeAI-Installer
for f in templates lib *.txt *.reg; do
for f in templates *.txt *.reg; do
cp -r ${f} InvokeAI-Installer/
done
mkdir InvokeAI-Installer/lib
cp lib/*.py InvokeAI-Installer/lib
# Move the wheel
mv dist/*.whl InvokeAI-Installer/lib/
@ -111,6 +113,6 @@ cp WinLongPathsEnabled.reg InvokeAI-Installer/
zip -r InvokeAI-installer-$VERSION.zip InvokeAI-Installer
# clean up
rm -rf InvokeAI-Installer tmp dist
rm -rf InvokeAI-Installer tmp dist ../invokeai/frontend/web/dist/
exit 0

View File

@ -244,9 +244,9 @@ class InvokeAiInstance:
"numpy~=1.24.0", # choose versions that won't be uninstalled during phase 2
"urllib3~=1.26.0",
"requests~=2.28.0",
"torch==2.1.0",
"torch==2.1.1",
"torchmetrics==0.11.4",
"torchvision>=0.14.1",
"torchvision>=0.16.1",
"--force-reinstall",
"--find-links" if find_links is not None else None,
find_links,

View File

@ -2,6 +2,7 @@
from logging import Logger
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -30,7 +31,6 @@ from ..services.session_processor.session_processor_default import DefaultSessio
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.shared.default_graphs import create_system_graphs
from ..services.shared.graph import GraphExecutionState, LibraryGraph
from ..services.shared.sqlite.sqlite_database import SqliteDatabase
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
@ -67,8 +67,9 @@ class ApiDependencies:
logger.debug(f"Internet connectivity is {config.internet_available}")
output_folder = config.output_path
image_files = DiskImageFileStorage(f"{output_folder}/images")
db = SqliteDatabase(config, logger)
db = init_db(config=config, logger=logger, image_files=image_files)
configuration = config
logger = logger
@ -80,7 +81,6 @@ class ApiDependencies:
events = FastAPIEventService(event_handler_id)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
image_files = DiskImageFileStorage(f"{output_folder}/images")
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)

View File

@ -45,6 +45,9 @@ async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
model_format: Optional[str] = Query(
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
@ -52,10 +55,14 @@ async def list_model_records(
if base_models:
for base_model in base_models:
found_models.extend(
record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name)
record_store.search_by_attr(
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
)
)
else:
found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name))
found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
)
return ModelsList(models=found_models)

View File

@ -13,7 +13,15 @@ from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from .baseinvocation import BaseInvocation, Input, InputField, InvocationContext, WithMetadata, invocation
from .baseinvocation import (
BaseInvocation,
Classification,
Input,
InputField,
InvocationContext,
WithMetadata,
invocation,
)
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
@ -421,6 +429,64 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata):
)
@invocation(
"unsharp_mask",
title="Unsharp Mask",
tags=["image", "unsharp_mask"],
category="image",
version="1.2.0",
classification=Classification.Beta,
)
class UnsharpMaskInvocation(BaseInvocation, WithMetadata):
"""Applies an unsharp mask filter to an image"""
image: ImageField = InputField(description="The image to use")
radius: float = InputField(gt=0, description="Unsharp mask radius", default=2)
strength: float = InputField(ge=0, description="Unsharp mask strength", default=50)
def pil_from_array(self, arr):
return Image.fromarray((arr * 255).astype("uint8"))
def array_from_pil(self, img):
return numpy.array(img) / 255
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
mode = image.mode
alpha_channel = image.getchannel("A") if mode == "RGBA" else None
image = image.convert("RGB")
image_blurred = self.array_from_pil(image.filter(ImageFilter.GaussianBlur(radius=self.radius)))
image = self.array_from_pil(image)
image += (image - image_blurred) * (self.strength / 100.0)
image = numpy.clip(image, 0, 1)
image = self.pil_from_array(image)
image = image.convert(mode)
# Make the image RGBA if we had a source alpha channel
if alpha_channel is not None:
image.putalpha(alpha_channel)
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=context.workflow,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image.width,
height=image.height,
)
PIL_RESAMPLING_MODES = Literal[
"nearest",
"box",

View File

@ -38,7 +38,14 @@ class CalculateImageTilesOutput(BaseInvocationOutput):
tiles: list[Tile] = OutputField(description="The tiles coordinates that cover a particular image shape.")
@invocation("calculate_image_tiles", title="Calculate Image Tiles", tags=["tiles"], category="tiles", version="1.0.0")
@invocation(
"calculate_image_tiles",
title="Calculate Image Tiles",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
class CalculateImageTilesInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""

View File

@ -20,63 +20,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self._conn = db.conn
self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `board_images` junction table."""
# Create the `board_images` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS board_images (
board_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between boards and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
)
# Add index for board id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
"""
)
# Add index for board id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
AFTER UPDATE
ON board_images FOR EACH ROW
BEGIN
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id AND image_name = old.image_name;
END;
"""
)
def add_image_to_board(
self,
board_id: str,

View File

@ -28,52 +28,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
self._conn = db.conn
self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `boards` table and `board_images` junction table."""
# Create the `boards` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS boards (
board_id TEXT NOT NULL PRIMARY KEY,
board_name TEXT NOT NULL,
cover_image_name TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
UPDATE boards SET updated_at = current_timestamp
WHERE board_id = old.board_id;
END;
"""
)
def delete(self, board_id: str) -> None:
try:
self._lock.acquire()

View File

@ -32,101 +32,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._conn = db.conn
self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `images` table."""
# Create the `images` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
session_id TEXT,
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
self._cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in self._cursor.fetchall()]
if "starred" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;
"""
)
# Create the `images` table indices.
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
)
self._cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in self._cursor.fetchall()]
if "has_workflow" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images
ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;
"""
)
def get(self, image_name: str) -> ImageRecord:
try:
self._lock.acquire()

View File

@ -5,7 +5,6 @@ from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from fastapi import Body
from pydantic import BaseModel, Field, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
@ -112,17 +111,7 @@ class URLModelSource(StringLikeSource):
return str(self.url)
# Body() is being applied here rather than Field() because otherwise FastAPI will
# refuse to generate a schema. Relevant links:
#
# "Model Manager Refactor Phase 1 - SQL-based config storage
# https://github.com/invoke-ai/InvokeAI/pull/5039#discussion_r1389752119 (comment)
# Param: xyz can only be a request body, using Body() when using discriminated unions
# https://github.com/tiangolo/fastapi/discussions/9761
# Body parameter cannot be a pydantic union anymore sinve v0.95
# https://github.com/tiangolo/fastapi/discussions/9287
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Body(discriminator="type")]
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
class ModelInstallJob(BaseModel):

View File

@ -7,10 +7,7 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.2.0"
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
class DuplicateModelException(Exception):
@ -32,12 +29,6 @@ class ConfigFileVersionMismatchException(Exception):
class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs."""
@property
@abstractmethod
def version(self) -> str:
"""Return the config file/database schema version."""
pass
@abstractmethod
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
@ -115,6 +106,7 @@ class ModelRecordServiceBase(ABC):
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
@ -122,6 +114,7 @@ class ModelRecordServiceBase(ABC):
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
:param model_format: Filter by model format (e.g. "diffusers") (optional)
If none of the optional filters are passed, will return all
models in the database.

View File

@ -49,12 +49,12 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelConfigFactory,
ModelFormat,
ModelType,
)
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
CONFIG_FILE_VERSION,
DuplicateModelException,
ModelRecordServiceBase,
UnknownModelException,
@ -78,86 +78,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db = db
self._cursor = self._db.conn.cursor()
with self._db.lock:
# Enable foreign keys
self._db.conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._db.conn.commit()
assert (
str(self.version) == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
def _create_tables(self) -> None:
"""Create sqlite3 tables."""
# model_config table breaks out the fields that are common to all config objects
# and puts class-specific ones in a serialized json object
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
# metadata table
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_manager_metadata (
metadata_key TEXT NOT NULL PRIMARY KEY,
metadata_value TEXT NOT NULL
);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
# Add indexes for searchable fields
for stmt in [
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
]:
self._cursor.execute(stmt)
# Add our version to the metadata table
self._cursor.execute(
"""--sql
INSERT OR IGNORE into model_manager_metadata (
metadata_key,
metadata_value
)
VALUES (?,?);
""",
("version", CONFIG_FILE_VERSION),
)
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Add a model to the database.
@ -207,22 +127,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return self.get_model(key)
@property
def version(self) -> str:
"""Return the version of the database schema."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata_value FROM model_manager_metadata
WHERE metadata_key=?;
""",
("version",),
)
rows = self._cursor.fetchone()
if not rows:
raise KeyError("Models database does not have metadata key 'version'")
return rows[0]
def del_model(self, key: str) -> None:
"""
Delete a model.
@ -322,6 +226,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
@ -329,6 +234,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
:param model_format: Filter by model format (e.g. "diffusers") (optional)
If none of the optional filters are passed, will return all
models in the database.
@ -345,6 +251,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._db.lock:
self._cursor.execute(

View File

@ -50,7 +50,6 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock = db.lock
self.__conn = db.conn
self.__cursor = self.__conn.cursor()
self._create_tables()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
return event[1]["event"] in match_in
@ -98,123 +97,6 @@ class SqliteSessionQueue(SessionQueueBase):
except SessionQueueItemNotFoundError:
return
def _create_tables(self) -> None:
"""Creates the session queue tables, indicies, and triggers"""
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
field_values TEXT, -- NULL if no values are associated with this queue item
session TEXT NOT NULL, -- the session to be executed
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
error TEXT, -- any errors associated with this queue item
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
started_at DATETIME, -- updated via trigger
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'completed'
OR NEW.status = 'failed'
OR NEW.status = 'canceled'
BEGIN
UPDATE session_queue
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'in_progress'
BEGIN
UPDATE session_queue
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
AFTER UPDATE
ON session_queue FOR EACH ROW
BEGIN
UPDATE session_queue
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = old.item_id;
END;
"""
)
self.__cursor.execute("PRAGMA table_info(session_queue)")
columns = [column[1] for column in self.__cursor.fetchall()]
if "workflow" not in columns:
self.__cursor.execute(
"""--sql
ALTER TABLE session_queue ADD COLUMN workflow TEXT;
"""
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.

View File

@ -3,45 +3,65 @@ import threading
from logging import Logger
from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
class SqliteDatabase:
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
self._logger = logger
self._config = config
"""
Manages a connection to an SQLite database.
if self._config.use_memory_db:
self.db_path = sqlite_memory
logger.info("Using in-memory database")
:param db_path: Path to the database file. If None, an in-memory database is used.
:param logger: Logger to use for logging.
:param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback.
This is a light wrapper around the `sqlite3` module, providing a few conveniences:
- The database file is written to disk if it does not exist.
- Foreign key constraints are enabled by default.
- The connection is configured to use the `sqlite3.Row` row factory.
In addition to the constructor args, the instance provides the following attributes and methods:
- `conn`: A `sqlite3.Connection` object. Note that the connection must never be closed if the database is in-memory.
- `lock`: A shared re-entrant lock, used to approximate thread safety.
- `clean()`: Runs the SQL `VACUUM;` command and reports on the freed space.
"""
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
"""Initializes the database. This is used internally by the class constructor."""
self.logger = logger
self.db_path = db_path
self.verbose = verbose
if not self.db_path:
logger.info("Initializing in-memory database")
else:
db_path = self._config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
self.db_path = str(db_path)
self._logger.info(f"Using database at {self.db_path}")
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.logger.info(f"Initializing database at {self.db_path}")
self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
self.lock = threading.RLock()
self.conn.row_factory = sqlite3.Row
if self._config.log_sql:
self.conn.set_trace_callback(self._logger.debug)
if self.verbose:
self.conn.set_trace_callback(self.logger.debug)
self.conn.execute("PRAGMA foreign_keys = ON;")
def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.
"""
# No need to clean in-memory database
if not self.db_path:
return
with self.lock:
try:
if self.db_path == sqlite_memory:
return
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self._logger.error(f"Error cleaning database: {e}")
self.logger.error(f"Error cleaning database: {e}")
raise

View File

@ -0,0 +1,32 @@
from logging import Logger
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileStorageBase) -> SqliteDatabase:
"""
Initializes the SQLite database.
:param config: The app config
:param logger: The logger
:param image_files: The image files service (used by migration 2)
This function:
- Instantiates a :class:`SqliteDatabase`
- Instantiates a :class:`SqliteMigrator` and registers all migrations
- Runs all migrations
"""
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
migrator = SqliteMigrator(db=db)
migrator.register_migration(build_migration_1())
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
migrator.run_migrations()
return db

View File

@ -0,0 +1,372 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration1Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
"""Migration callback for database version 1."""
self._create_board_images(cursor)
self._create_boards(cursor)
self._create_images(cursor)
self._create_model_config(cursor)
self._create_session_queue(cursor)
self._create_workflow_images(cursor)
self._create_workflows(cursor)
def _create_board_images(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `board_images` table, indices and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS board_images (
board_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between boards and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
]
indices = [
"CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);",
"CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
AFTER UPDATE
ON board_images FOR EACH ROW
BEGIN
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id AND image_name = old.image_name;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_boards(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `boards` table, indices and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS boards (
board_id TEXT NOT NULL PRIMARY KEY,
board_name TEXT NOT NULL,
cover_image_name TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
);
"""
]
indices = ["CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);"]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
UPDATE boards SET updated_at = current_timestamp
WHERE board_id = old.board_id;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_images(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `images` table, indices and triggers. Adds the `starred` column."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
session_id TEXT,
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
]
indices = [
"CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);",
"CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);",
"CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);",
"CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
]
# Add the 'starred' column to `images` if it doesn't exist
cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()]
if "starred" not in columns:
tables.append("ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;")
indices.append("CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);")
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_model_config(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `model_config` table, `model_manager_metadata` table, indices and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT NOT NULL,
type TEXT NOT NULL,
name TEXT NOT NULL,
path TEXT NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
""",
"""--sql
CREATE TABLE IF NOT EXISTS model_manager_metadata (
metadata_key TEXT NOT NULL PRIMARY KEY,
metadata_value TEXT NOT NULL
);
""",
]
# Add trigger for `updated_at`.
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
]
# Add indexes for searchable fields
indices = [
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_session_queue(self, cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
field_values TEXT, -- NULL if no values are associated with this queue item
session TEXT NOT NULL, -- the session to be executed
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
error TEXT, -- any errors associated with this queue item
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
started_at DATETIME, -- updated via trigger
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
);
"""
]
indices = [
"CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);",
"CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);",
"CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);",
"CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);",
"CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'completed'
OR NEW.status = 'failed'
OR NEW.status = 'canceled'
BEGIN
UPDATE session_queue
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
""",
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'in_progress'
BEGIN
UPDATE session_queue
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
""",
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
AFTER UPDATE
ON session_queue FOR EACH ROW
BEGIN
UPDATE session_queue
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = old.item_id;
END;
""",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_workflow_images(self, cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflow_images (
workflow_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between workflows and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
]
indices = [
"CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);",
"CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
AFTER UPDATE
ON workflow_images FOR EACH ROW
BEGIN
UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_workflows(self, cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflows (
workflow TEXT NOT NULL,
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
);
"""
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
AFTER UPDATE
ON workflows FOR EACH ROW
BEGIN
UPDATE workflows
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
]
for stmt in tables + triggers:
cursor.execute(stmt)
def build_migration_1() -> Migration:
"""
Builds the migration from database version 0 (init) to 1.
This migration represents the state of the database circa InvokeAI v3.4.0, which was the last
version to not use migrations to manage the database.
As such, this migration does include some ALTER statements, and the SQL statements are written
to be idempotent.
- Create `board_images` junction table
- Create `boards` table
- Create `images` table, add `starred` column
- Create `model_config` table
- Create `session_queue` table
- Create `workflow_images` junction table
- Create `workflows` table
"""
migration_1 = Migration(
from_version=0,
to_version=1,
callback=Migration1Callback(),
)
return migration_1

View File

@ -0,0 +1,198 @@
import sqlite3
from logging import Logger
from pydantic import ValidationError
from tqdm import tqdm
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.image_files.image_files_common import ImageFileNotFoundException
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from invokeai.app.services.workflow_records.workflow_records_common import (
UnsafeWorkflowWithVersionValidator,
)
class Migration2Callback:
def __init__(self, image_files: ImageFileStorageBase, logger: Logger):
self._image_files = image_files
self._logger = logger
def __call__(self, cursor: sqlite3.Cursor):
self._add_images_has_workflow(cursor)
self._add_session_queue_workflow(cursor)
self._drop_old_workflow_tables(cursor)
self._add_workflow_library(cursor)
self._drop_model_manager_metadata(cursor)
self._recreate_model_config(cursor)
self._migrate_embedded_workflows(cursor)
def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
"""Add the `has_workflow` column to `images` table."""
cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()]
if "has_workflow" not in columns:
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
def _add_session_queue_workflow(self, cursor: sqlite3.Cursor) -> None:
"""Add the `workflow` column to `session_queue` table."""
cursor.execute("PRAGMA table_info(session_queue)")
columns = [column[1] for column in cursor.fetchall()]
if "workflow" not in columns:
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
def _drop_old_workflow_tables(self, cursor: sqlite3.Cursor) -> None:
"""Drops the `workflows` and `workflow_images` tables."""
cursor.execute("DROP TABLE IF EXISTS workflow_images;")
cursor.execute("DROP TABLE IF EXISTS workflows;")
def _add_workflow_library(self, cursor: sqlite3.Cursor) -> None:
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflow_library (
workflow_id TEXT NOT NULL PRIMARY KEY,
workflow TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated manually when retrieving workflow
opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Generated columns, needed for indexing and searching
category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
);
""",
]
indices = [
"CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
AFTER UPDATE
ON workflow_library FOR EACH ROW
BEGIN
UPDATE workflow_library
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
"""Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
"""
Drops the `model_config` table, recreating it.
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
Because this table is not used in production, we are able to simply drop it and recreate it.
"""
cursor.execute("DROP TABLE IF EXISTS model_config;")
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
This migrate callback checks each image for the presence of an embedded workflow, then updates its entry
in the database accordingly.
"""
# Get all image names
cursor.execute("SELECT image_name FROM images")
image_names: list[str] = [image[0] for image in cursor.fetchall()]
total_image_names = len(image_names)
if not total_image_names:
return
self._logger.info(f"Migrating workflows for {total_image_names} images")
# Migrate the images
to_migrate: list[tuple[bool, str]] = []
pbar = tqdm(image_names)
for idx, image_name in enumerate(pbar):
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
try:
pil_image = self._image_files.get(image_name)
except ImageFileNotFoundException:
self._logger.warning(f"Image {image_name} not found, skipping")
continue
if "invokeai_workflow" in pil_image.info:
try:
UnsafeWorkflowWithVersionValidator.validate_json(pil_image.info.get("invokeai_workflow", ""))
except ValidationError:
self._logger.warning(f"Image {image_name} has invalid embedded workflow, skipping")
continue
to_migrate.append((True, image_name))
self._logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
def build_migration_2(image_files: ImageFileStorageBase, logger: Logger) -> Migration:
"""
Builds the migration from database version 1 to 2.
Introduced in v3.5.0 for the new workflow library.
:param image_files: The image files service, used to check for embedded workflows
:param logger: The logger, used to log progress during embedded workflows handling
This migration does the following:
- Add `has_workflow` column to `images` table
- Add `workflow` column to `session_queue` table
- Drop `workflows` and `workflow_images` tables
- Add `workflow_library` table
- Drops the `model_manager_metadata` table
- Drops the `model_config` table, recreating it (at this point, there is no user data in this table)
- Populates the `has_workflow` column in the `images` table (requires `image_files` & `logger` dependencies)
"""
migration_2 = Migration(
from_version=1,
to_version=2,
callback=Migration2Callback(image_files=image_files, logger=logger),
)
return migration_2

View File

@ -0,0 +1,164 @@
import sqlite3
from typing import Optional, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field, model_validator
@runtime_checkable
class MigrateCallback(Protocol):
"""
A callback that performs a migration.
Migrate callbacks are provided an open cursor to the database. They should not commit their
transaction; this is handled by the migrator.
If the callback needs to access additional dependencies, will be provided to the callback at runtime.
See :class:`Migration` for an example.
"""
def __call__(self, cursor: sqlite3.Cursor) -> None:
...
class MigrationError(RuntimeError):
"""Raised when a migration fails."""
class MigrationVersionError(ValueError):
"""Raised when a migration version is invalid."""
class Migration(BaseModel):
"""
Represents a migration for a SQLite database.
:param from_version: The database version on which this migration may be run
:param to_version: The database version that results from this migration
:param migrate_callback: The callback to run to perform the migration
Migration callbacks will be provided an open cursor to the database. They should not commit their
transaction; this is handled by the migrator.
It is suggested to use a class to define the migration callback and a builder function to create
the :class:`Migration`. This allows the callback to be provided with additional dependencies and
keeps things tidy, as all migration logic is self-contained.
Example:
```py
# Define the migration callback class
class Migration1Callback:
# This migration needs a logger, so we define a class that accepts a logger in its constructor.
def __init__(self, image_files: ImageFileStorageBase) -> None:
self._image_files = ImageFileStorageBase
# This dunder method allows the instance of the class to be called like a function.
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_with_banana_column(cursor)
self._do_something_with_images(cursor)
def _add_with_banana_column(self, cursor: sqlite3.Cursor) -> None:
\"""Adds the with_banana column to the sushi table.\"""
# Execute SQL using the cursor, taking care to *not commit* a transaction
cursor.execute('ALTER TABLE sushi ADD COLUMN with_banana BOOLEAN DEFAULT TRUE;')
def _do_something_with_images(self, cursor: sqlite3.Cursor) -> None:
\"""Does something with the image files service.\"""
self._image_files.get(...)
# Define the migration builder function. This function creates an instance of the migration callback
# class and returns a Migration.
def build_migration_1(image_files: ImageFileStorageBase) -> Migration:
\"""Builds the migration from database version 0 to 1.
Requires the image files service to...
\"""
migration_1 = Migration(
from_version=0,
to_version=1,
migrate_callback=Migration1Callback(image_files=image_files),
)
return migration_1
# Register the migration after all dependencies have been initialized
db = SqliteDatabase(db_path, logger)
migrator = SqliteMigrator(db)
migrator.register_migration(build_migration_1(image_files))
migrator.run_migrations()
```
"""
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
callback: MigrateCallback = Field(description="The callback to run to perform the migration")
@model_validator(mode="after")
def validate_to_version(self) -> "Migration":
"""Validates that to_version is one greater than from_version."""
if self.to_version != self.from_version + 1:
raise MigrationVersionError("to_version must be one greater than from_version")
return self
def __hash__(self) -> int:
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
return hash((self.from_version, self.to_version))
model_config = ConfigDict(arbitrary_types_allowed=True)
class MigrationSet:
"""
A set of Migrations. Performs validation during migration registration and provides utility methods.
Migrations should be registered with `register()`. Once all are registered, `validate_migration_chain()`
should be called to ensure that the migrations form a single chain of migrations from version 0 to the latest version.
"""
def __init__(self) -> None:
self._migrations: set[Migration] = set()
def register(self, migration: Migration) -> None:
"""Registers a migration."""
migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
if migration_from_already_registered or migration_to_already_registered:
raise MigrationVersionError("Migration with from_version or to_version already registered")
self._migrations.add(migration)
def get(self, from_version: int) -> Optional[Migration]:
"""Gets the migration that may be run on the given database version."""
# register() ensures that there is only one migration with a given from_version, so this is safe.
return next((m for m in self._migrations if m.from_version == from_version), None)
def validate_migration_chain(self) -> None:
"""
Validates that the migrations form a single chain of migrations from version 0 to the latest version,
Raises a MigrationError if there is a problem.
"""
if self.count == 0:
return
if self.latest_version == 0:
return
next_migration = self.get(from_version=0)
if next_migration is None:
raise MigrationError("Migration chain is fragmented")
touched_count = 1
while next_migration is not None:
next_migration = self.get(next_migration.to_version)
if next_migration is not None:
touched_count += 1
if touched_count != self.count:
raise MigrationError("Migration chain is fragmented")
@property
def count(self) -> int:
"""The count of registered migrations."""
return len(self._migrations)
@property
def latest_version(self) -> int:
"""Gets latest to_version among registered migrations. Returns 0 if there are no migrations registered."""
if self.count == 0:
return 0
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version

View File

@ -0,0 +1,130 @@
import sqlite3
from pathlib import Path
from typing import Optional
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationError, MigrationSet
class SqliteMigrator:
"""
Manages migrations for a SQLite database.
:param db: The instance of :class:`SqliteDatabase` to migrate.
Migrations should be registered with :meth:`register_migration`.
Each migration is run in a transaction. If a migration fails, the transaction is rolled back.
Example Usage:
```py
db = SqliteDatabase(db_path="my_db.db", logger=logger)
migrator = SqliteMigrator(db=db)
migrator.register_migration(build_migration_1())
migrator.register_migration(build_migration_2())
migrator.run_migrations()
```
"""
backup_path: Optional[Path] = None
def __init__(self, db: SqliteDatabase) -> None:
self._db = db
self._logger = db.logger
self._migration_set = MigrationSet()
def register_migration(self, migration: Migration) -> None:
"""Registers a migration."""
self._migration_set.register(migration)
self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
def run_migrations(self) -> bool:
"""Migrates the database to the latest version."""
with self._db.lock:
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
cursor = self._db.conn.cursor()
self._create_migrations_table(cursor=cursor)
if self._migration_set.count == 0:
self._logger.debug("No migrations registered")
return False
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
self._logger.debug("Database is up to date, no migrations to run")
return False
self._logger.info("Database update needed")
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
while next_migration is not None:
self._run_migration(next_migration)
next_migration = self._migration_set.get(self._get_current_version(cursor))
self._logger.info("Database updated successfully")
return True
def _run_migration(self, migration: Migration) -> None:
"""Runs a single migration."""
try:
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
# exception is raised.
with self._db.lock, self._db.conn as conn:
cursor = conn.cursor()
if self._get_current_version(cursor) != migration.from_version:
raise MigrationError(
f"Database is at version {self._get_current_version(cursor)}, expected {migration.from_version}"
)
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
# Run the actual migration
migration.callback(cursor)
# Update the version
cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
self._logger.debug(
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
)
# We want to catch *any* error, mirroring the behaviour of the sqlite3 module.
except Exception as e:
# The connection context manager has already rolled back the migration, so we don't need to do anything.
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
self._logger.error(msg)
raise MigrationError(msg) from e
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the migrations table for the database, if one does not already exist."""
with self._db.lock:
try:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if cursor.fetchone() is not None:
return
cursor.execute(
"""--sql
CREATE TABLE migrations (
version INTEGER PRIMARY KEY,
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
cursor.connection.commit()
self._logger.debug("Created migrations table")
except sqlite3.Error as e:
msg = f"Problem creating migrations table: {e}"
self._logger.error(msg)
cursor.connection.rollback()
raise MigrationError(msg) from e
@classmethod
def _get_current_version(cls, cursor: sqlite3.Cursor) -> int:
"""Gets the current version of the database, or 0 if the migrations table does not exist."""
try:
cursor.execute("SELECT MAX(version) FROM migrations;")
version: int = cursor.fetchone()[0]
if version is None:
return 0
return version
except sqlite3.OperationalError as e:
if "no such table" in str(e):
return 0
raise

View File

@ -65,12 +65,24 @@ class WorkflowWithoutID(BaseModel):
nodes: list[dict[str, JsonValue]] = Field(description="The nodes of the workflow.")
edges: list[dict[str, JsonValue]] = Field(description="The edges of the workflow.")
model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="ignore")
WorkflowWithoutIDValidator = TypeAdapter(WorkflowWithoutID)
class UnsafeWorkflowWithVersion(BaseModel):
"""
This utility model only requires a workflow to have a valid version string.
It is used to validate a workflow version without having to validate the entire workflow.
"""
meta: WorkflowMeta = Field(description="The meta of the workflow.")
UnsafeWorkflowWithVersionValidator = TypeAdapter(UnsafeWorkflowWithVersion)
class Workflow(WorkflowWithoutID):
id: str = Field(description="The id of the workflow.")

View File

@ -26,7 +26,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
self._create_tables()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@ -233,87 +232,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
raise
finally:
self._lock.release()
def _create_tables(self) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS workflow_library (
workflow_id TEXT NOT NULL PRIMARY KEY,
workflow TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated manually when retrieving workflow
opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Generated columns, needed for indexing and searching
category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
);
"""
)
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
AFTER UPDATE
ON workflow_library FOR EACH ROW
BEGIN
UPDATE workflow_library
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);
"""
)
# We do not need the original `workflows` table or `workflow_images` junction table.
self._cursor.execute(
"""--sql
DROP TABLE IF EXISTS workflow_images;
"""
)
self._cursor.execute(
"""--sql
DROP TABLE IF EXISTS workflows;
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()

View File

@ -9,7 +9,7 @@ def lora_token_vector_length(checkpoint: dict) -> int:
:param checkpoint: The checkpoint
"""
def _get_shape_1(key, tensor, checkpoint):
def _get_shape_1(key: str, tensor, checkpoint) -> int:
lora_token_vector_length = None
if "." not in key:
@ -57,6 +57,10 @@ def lora_token_vector_length(checkpoint: dict) -> int:
for key, tensor in checkpoint.items():
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_unet_") and (
"time_emb_proj.lora_down" in key
): # recognizes format at https://civitai.com/models/224641
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_te") and "_self_attn_" in key:
tmp_length = _get_shape_1(key, tensor, checkpoint)
if key.startswith("lora_te_"):

View File

@ -49,7 +49,8 @@ class MigrateModelYamlToDb:
def get_db(self) -> ModelRecordServiceSQL:
"""Fetch the sqlite3 database for this installation."""
db = SqliteDatabase(self.config, self.logger)
db_path = None if self.config.use_memory_db else self.config.db_path
db = SqliteDatabase(db_path=db_path, logger=self.logger, verbose=self.config.log_sql)
return ModelRecordServiceSQL(db)
def get_yaml(self) -> DictConfig:

View File

@ -400,6 +400,8 @@ class LoRACheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:

View File

@ -242,17 +242,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
control_model: ControlNetModel = None,
):
super().__init__(
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker,
feature_extractor,
requires_safety_checker,
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
@ -260,9 +249,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
# FIXME: can't currently register control module
# control_model=control_model,
requires_safety_checker=requires_safety_checker,
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self.control_model = control_model
self.use_ip_adapter = False

View File

@ -950,9 +950,9 @@
"problemSettingTitle": "Problem Setting Title",
"reloadNodeTemplates": "Reload Node Templates",
"removeLinearView": "Remove from Linear View",
"resetWorkflow": "Reset Workflow Editor",
"resetWorkflowDesc": "Are you sure you want to reset the Workflow Editor?",
"resetWorkflowDesc2": "Resetting the Workflow Editor will clear all nodes, edges and workflow details. Saved workflows will not be affected.",
"newWorkflow": "New Workflow",
"newWorkflowDesc": "Create a new workflow?",
"newWorkflowDesc2": "Your current workflow has unsaved changes.",
"scheduler": "Scheduler",
"schedulerDescription": "TODO",
"sDXLMainModelField": "SDXL Model",
@ -1634,10 +1634,10 @@
"userWorkflows": "My Workflows",
"defaultWorkflows": "Default Workflows",
"openWorkflow": "Open Workflow",
"uploadWorkflow": "Upload Workflow",
"uploadWorkflow": "Load from File",
"deleteWorkflow": "Delete Workflow",
"unnamedWorkflow": "Unnamed Workflow",
"downloadWorkflow": "Download Workflow",
"downloadWorkflow": "Save to File",
"saveWorkflow": "Save Workflow",
"saveWorkflowAs": "Save Workflow As",
"savingWorkflow": "Saving Workflow...",
@ -1652,7 +1652,7 @@
"searchWorkflows": "Search Workflows",
"clearWorkflowSearchFilter": "Clear Workflow Search Filter",
"workflowName": "Workflow Name",
"workflowEditorReset": "Workflow Editor Reset",
"newWorkflowCreated": "New Workflow Created",
"workflowEditorMenu": "Workflow Editor Menu",
"workflowIsOpen": "Workflow is Open"
},

View File

@ -727,9 +727,6 @@
"showMinimapnodes": "Mostrar el minimapa",
"reloadNodeTemplates": "Recargar las plantillas de nodos",
"loadWorkflow": "Cargar el flujo de trabajo",
"resetWorkflow": "Reiniciar e flujo de trabajo",
"resetWorkflowDesc": "¿Está seguro de que deseas restablecer este flujo de trabajo?",
"resetWorkflowDesc2": "Al reiniciar el flujo de trabajo se borrarán todos los nodos, aristas y detalles del flujo de trabajo.",
"downloadWorkflow": "Descargar el flujo de trabajo en un archivo JSON"
}
}

View File

@ -104,7 +104,16 @@
"copyError": "$t(gallery.copy) Errore",
"input": "Ingresso",
"notInstalled": "Non $t(common.installed)",
"unknownError": "Errore sconosciuto"
"unknownError": "Errore sconosciuto",
"updated": "Aggiornato",
"save": "Salva",
"created": "Creato",
"prevPage": "Pagina precedente",
"delete": "Elimina",
"orderBy": "Ordinato per",
"nextPage": "Pagina successiva",
"saveAs": "Salva come",
"unsaved": "Non salvato"
},
"gallery": {
"generations": "Generazioni",
@ -763,7 +772,10 @@
"setIPAdapterImage": "Imposta come immagine per l'Adattatore IP",
"problemSavingMaskDesc": "Impossibile salvare la maschera",
"setAsCanvasInitialImage": "Imposta come immagine iniziale della tela",
"invalidUpload": "Caricamento non valido"
"invalidUpload": "Caricamento non valido",
"problemDeletingWorkflow": "Problema durante l'eliminazione del flusso di lavoro",
"workflowDeleted": "Flusso di lavoro eliminato",
"problemRetrievingWorkflow": "Problema nel recupero del flusso di lavoro"
},
"tooltip": {
"feature": {
@ -886,11 +898,8 @@
"zoomInNodes": "Ingrandire",
"fitViewportNodes": "Adatta vista",
"showGraphNodes": "Mostra sovrapposizione grafico",
"resetWorkflowDesc2": "Reimpostare il flusso di lavoro cancellerà tutti i nodi, i bordi e i dettagli del flusso di lavoro.",
"reloadNodeTemplates": "Ricarica i modelli di nodo",
"loadWorkflow": "Importa flusso di lavoro JSON",
"resetWorkflow": "Reimposta flusso di lavoro",
"resetWorkflowDesc": "Sei sicuro di voler reimpostare questo flusso di lavoro?",
"downloadWorkflow": "Esporta flusso di lavoro JSON",
"scheduler": "Campionatore",
"addNode": "Aggiungi nodo",
@ -1080,25 +1089,27 @@
"collectionOrScalarFieldType": "{{name}} Raccolta|Scalare",
"nodeVersion": "Versione Nodo",
"inputFieldTypeParseError": "Impossibile analizzare il tipo di campo di input {{node}}.{{field}} ({{message}})",
"unsupportedArrayItemType": "tipo di elemento dell'array non supportato \"{{type}}\"",
"unsupportedArrayItemType": "Tipo di elemento dell'array non supportato \"{{type}}\"",
"targetNodeFieldDoesNotExist": "Connessione non valida: il campo di destinazione/input {{node}}.{{field}} non esiste",
"unsupportedMismatchedUnion": "tipo CollectionOrScalar non corrispondente con tipi di base {{firstType}} e {{secondType}}",
"allNodesUpdated": "Tutti i nodi sono aggiornati",
"sourceNodeDoesNotExist": "Connessione non valida: il nodo di origine/output {{node}} non esiste",
"unableToExtractEnumOptions": "impossibile estrarre le opzioni enum",
"unableToParseFieldType": "impossibile analizzare il tipo di campo",
"unableToExtractEnumOptions": "Impossibile estrarre le opzioni enum",
"unableToParseFieldType": "Impossibile analizzare il tipo di campo",
"unrecognizedWorkflowVersion": "Versione dello schema del flusso di lavoro non riconosciuta {{version}}",
"outputFieldTypeParseError": "Impossibile analizzare il tipo di campo di output {{node}}.{{field}} ({{message}})",
"sourceNodeFieldDoesNotExist": "Connessione non valida: il campo di origine/output {{node}}.{{field}} non esiste",
"unableToGetWorkflowVersion": "Impossibile ottenere la versione dello schema del flusso di lavoro",
"nodePack": "Pacchetto di nodi",
"unableToExtractSchemaNameFromRef": "impossibile estrarre il nome dello schema dal riferimento",
"unableToExtractSchemaNameFromRef": "Impossibile estrarre il nome dello schema dal riferimento",
"unknownOutput": "Output sconosciuto: {{name}}",
"unknownNodeType": "Tipo di nodo sconosciuto",
"targetNodeDoesNotExist": "Connessione non valida: il nodo di destinazione/input {{node}} non esiste",
"unknownFieldType": "$t(nodes.unknownField) tipo: {{type}}",
"deletedInvalidEdge": "Eliminata connessione non valida {{source}} -> {{target}}",
"unknownInput": "Input sconosciuto: {{name}}"
"unknownInput": "Input sconosciuto: {{name}}",
"prototypeDesc": "Questa invocazione è un prototipo. Potrebbe subire modifiche sostanziali durante gli aggiornamenti dell'app e potrebbe essere rimossa in qualsiasi momento.",
"betaDesc": "Questa invocazione è in versione beta. Fino a quando non sarà stabile, potrebbe subire modifiche importanti durante gli aggiornamenti dell'app. Abbiamo intenzione di supportare questa invocazione a lungo termine."
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
@ -1594,5 +1605,33 @@
"hrf": "Correzione Alta Risoluzione",
"hrfStrength": "Forza della Correzione Alta Risoluzione",
"strengthTooltip": "Valori più bassi comportano meno dettagli, il che può ridurre potenziali artefatti."
},
"workflows": {
"saveWorkflowAs": "Salva flusso di lavoro come",
"workflowEditorMenu": "Menu dell'editor del flusso di lavoro",
"noSystemWorkflows": "Nessun flusso di lavoro del sistema",
"workflowName": "Nome del flusso di lavoro",
"noUserWorkflows": "Nessun flusso di lavoro utente",
"defaultWorkflows": "Flussi di lavoro predefiniti",
"saveWorkflow": "Salva flusso di lavoro",
"openWorkflow": "Apri flusso di lavoro",
"clearWorkflowSearchFilter": "Cancella il filtro di ricerca del flusso di lavoro",
"workflowLibrary": "Libreria",
"noRecentWorkflows": "Nessun flusso di lavoro recente",
"workflowSaved": "Flusso di lavoro salvato",
"workflowIsOpen": "Il flusso di lavoro è aperto",
"unnamedWorkflow": "Flusso di lavoro senza nome",
"savingWorkflow": "Salvataggio del flusso di lavoro...",
"problemLoading": "Problema durante il caricamento dei flussi di lavoro",
"loading": "Caricamento dei flussi di lavoro",
"searchWorkflows": "Cerca flussi di lavoro",
"problemSavingWorkflow": "Problema durante il salvataggio del flusso di lavoro",
"deleteWorkflow": "Elimina flusso di lavoro",
"workflows": "Flussi di lavoro",
"noDescription": "Nessuna descrizione",
"userWorkflows": "I miei flussi di lavoro"
},
"app": {
"storeNotInitialized": "Il negozio non è inizializzato"
}
}

View File

@ -844,9 +844,6 @@
"hideLegendNodes": "Typelegende veld verbergen",
"reloadNodeTemplates": "Herlaad knooppuntsjablonen",
"loadWorkflow": "Laad werkstroom",
"resetWorkflow": "Herstel werkstroom",
"resetWorkflowDesc": "Weet je zeker dat je deze werkstroom wilt herstellen?",
"resetWorkflowDesc2": "Herstel van een werkstroom haalt alle knooppunten, randen en werkstroomdetails weg.",
"downloadWorkflow": "Download JSON van werkstroom",
"booleanPolymorphicDescription": "Een verzameling Booleanse waarden.",
"scheduler": "Planner",

File diff suppressed because it is too large Load Diff

View File

@ -110,7 +110,17 @@
"copyError": "$t(gallery.copy) 错误",
"input": "输入",
"notInstalled": "非 $t(common.installed)",
"delete": "删除"
"delete": "删除",
"updated": "已上传",
"save": "保存",
"created": "已创建",
"prevPage": "上一页",
"unknownError": "未知错误",
"direction": "指向",
"orderBy": "排序方式:",
"nextPage": "下一页",
"saveAs": "保存为",
"unsaved": "未保存"
},
"gallery": {
"generations": "生成的图像",
@ -146,7 +156,11 @@
"image": "图像",
"drop": "弃用",
"dropOrUpload": "$t(gallery.drop) 或上传",
"dropToUpload": "$t(gallery.drop) 以上传"
"dropToUpload": "$t(gallery.drop) 以上传",
"problemDeletingImagesDesc": "有一张或多张图像无法被删除",
"problemDeletingImages": "删除图像时出现问题",
"unstarImage": "取消收藏图像",
"starImage": "收藏图像"
},
"hotkeys": {
"keyboardShortcuts": "键盘快捷键",
@ -724,7 +738,7 @@
"nodesUnrecognizedTypes": "无法加载。节点图有无法识别的节点类型",
"nodesNotValidJSON": "无效的 JSON",
"nodesNotValidGraph": "无效的 InvokeAi 节点图",
"nodesLoadedFailed": "节点加载失败",
"nodesLoadedFailed": "节点加载失败",
"modelAddedSimple": "已添加模型",
"modelAdded": "已添加模型: {{modelName}}",
"imageSavingFailed": "图像保存失败",
@ -760,7 +774,10 @@
"problemImportingMask": "导入遮罩时出现问题",
"baseModelChangedCleared_other": "基础模型已更改, 已清除或禁用 {{count}} 个不兼容的子模型",
"setAsCanvasInitialImage": "设为画布初始图像",
"invalidUpload": "无效的上传"
"invalidUpload": "无效的上传",
"problemDeletingWorkflow": "删除工作流时出现问题",
"workflowDeleted": "已删除工作流",
"problemRetrievingWorkflow": "检索工作流时发生问题"
},
"unifiedCanvas": {
"layer": "图层",
@ -875,11 +892,8 @@
},
"nodes": {
"zoomInNodes": "放大",
"resetWorkflowDesc": "是否确定要清空节点图?",
"resetWorkflow": "清空节点图",
"loadWorkflow": "读取节点图",
"loadWorkflow": "加载工作流",
"zoomOutNodes": "缩小",
"resetWorkflowDesc2": "重置节点图将清除所有节点、边际和节点图详情.",
"reloadNodeTemplates": "重载节点模板",
"hideGraphNodes": "隐藏节点图信息",
"fitViewportNodes": "自适应视图",
@ -888,7 +902,7 @@
"showLegendNodes": "显示字段类型图例",
"hideLegendNodes": "隐藏字段类型图例",
"showGraphNodes": "显示节点图信息",
"downloadWorkflow": "下载节点图 JSON",
"downloadWorkflow": "下载工作流 JSON",
"workflowDescription": "简述",
"versionUnknown": " 未知版本",
"noNodeSelected": "无选中的节点",
@ -1103,7 +1117,9 @@
"collectionOrScalarFieldType": "{{name}} 合集 | 标量",
"nodeVersion": "节点版本",
"deletedInvalidEdge": "已删除无效的边缘 {{source}} -> {{target}}",
"unknownInput": "未知输入:{{name}}"
"unknownInput": "未知输入:{{name}}",
"prototypeDesc": "此调用是一个原型 (prototype)。它可能会在本项目更新期间发生破坏性更改,并且随时可能被删除。",
"betaDesc": "此调用尚处于测试阶段。在稳定之前,它可能会在项目更新期间发生破坏性更改。本项目计划长期支持这种调用。"
},
"controlnet": {
"resize": "直接缩放",
@ -1607,5 +1623,35 @@
"hrf": "高分辨率修复",
"hrfStrength": "高分辨率修复强度",
"strengthTooltip": "值越低细节越少,但可以减少部分潜在的伪影。"
},
"workflows": {
"saveWorkflowAs": "保存工作流为",
"workflowEditorMenu": "工作流编辑器菜单",
"noSystemWorkflows": "无系统工作流",
"workflowName": "工作流名称",
"noUserWorkflows": "无用户工作流",
"defaultWorkflows": "默认工作流",
"saveWorkflow": "保存工作流",
"openWorkflow": "打开工作流",
"clearWorkflowSearchFilter": "清除工作流检索过滤器",
"workflowLibrary": "工作流库",
"downloadWorkflow": "下载工作流",
"noRecentWorkflows": "无最近工作流",
"workflowSaved": "已保存工作流",
"workflowIsOpen": "工作流已打开",
"unnamedWorkflow": "未命名的工作流",
"savingWorkflow": "保存工作流中...",
"problemLoading": "加载工作流时出现问题",
"loading": "加载工作流中",
"searchWorkflows": "检索工作流",
"problemSavingWorkflow": "保存工作流时出现问题",
"deleteWorkflow": "删除工作流",
"workflows": "工作流",
"noDescription": "无描述",
"uploadWorkflow": "上传工作流",
"userWorkflows": "我的工作流"
},
"app": {
"storeNotInitialized": "商店尚未初始化"
}
}

View File

@ -144,6 +144,7 @@ export const buildCanvasImageToImageGraph = (
type: 'l2i',
id: CANVAS_OUTPUT,
is_intermediate,
use_cache: false,
},
},
edges: [
@ -255,6 +256,7 @@ export const buildCanvasImageToImageGraph = (
is_intermediate,
width: width,
height: height,
use_cache: false,
};
graph.edges.push(
@ -295,6 +297,7 @@ export const buildCanvasImageToImageGraph = (
id: CANVAS_OUTPUT,
is_intermediate,
fp32,
use_cache: false,
};
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image =

View File

@ -191,6 +191,7 @@ export const buildCanvasInpaintGraph = (
id: CANVAS_OUTPUT,
is_intermediate,
reference: canvasInitImage,
use_cache: false,
},
},
edges: [

View File

@ -199,6 +199,7 @@ export const buildCanvasOutpaintGraph = (
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate,
use_cache: false,
},
},
edges: [

View File

@ -266,6 +266,7 @@ export const buildCanvasSDXLImageToImageGraph = (
is_intermediate,
width: width,
height: height,
use_cache: false,
};
graph.edges.push(
@ -306,6 +307,7 @@ export const buildCanvasSDXLImageToImageGraph = (
id: CANVAS_OUTPUT,
is_intermediate,
fp32,
use_cache: false,
};
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image =

View File

@ -196,6 +196,7 @@ export const buildCanvasSDXLInpaintGraph = (
id: CANVAS_OUTPUT,
is_intermediate,
reference: canvasInitImage,
use_cache: false,
},
},
edges: [

View File

@ -204,6 +204,7 @@ export const buildCanvasSDXLOutpaintGraph = (
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate,
use_cache: false,
},
},
edges: [

View File

@ -258,6 +258,7 @@ export const buildCanvasSDXLTextToImageGraph = (
is_intermediate,
width: width,
height: height,
use_cache: false,
};
graph.edges.push(
@ -288,6 +289,7 @@ export const buildCanvasSDXLTextToImageGraph = (
id: CANVAS_OUTPUT,
is_intermediate,
fp32,
use_cache: false,
};
graph.edges.push({

View File

@ -246,6 +246,7 @@ export const buildCanvasTextToImageGraph = (
is_intermediate,
width: width,
height: height,
use_cache: false,
};
graph.edges.push(
@ -276,6 +277,7 @@ export const buildCanvasTextToImageGraph = (
id: CANVAS_OUTPUT,
is_intermediate,
fp32,
use_cache: false,
};
graph.edges.push({

View File

@ -143,6 +143,7 @@ export const buildLinearImageToImageGraph = (
// },
fp32,
is_intermediate,
use_cache: false,
},
},
edges: [

View File

@ -154,6 +154,7 @@ export const buildLinearSDXLImageToImageGraph = (
// },
fp32,
is_intermediate,
use_cache: false,
},
},
edges: [

View File

@ -127,6 +127,7 @@ export const buildLinearSDXLTextToImageGraph = (
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
use_cache: false,
},
},
edges: [

View File

@ -146,6 +146,7 @@ export const buildLinearTextToImageGraph = (
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
use_cache: false,
},
},
edges: [

View File

@ -64,7 +64,10 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
const nodePack = invocationTemplate
? invocationTemplate.nodePack
: t('common.unknown');
(node.data as unknown as InvocationNodeData).nodePack = nodePack;
// Fallback to 1.0.0 if not specified - this matches the behavior of the backend
node.data.version ||= '1.0.0';
}
});
// Bump version

View File

@ -11,44 +11,48 @@ import {
Text,
useDisclosure,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
import { FaCircleNodes } from 'react-icons/fa6';
const ResetWorkflowEditorMenuItem = () => {
const NewWorkflowMenuItem = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { isOpen, onOpen, onClose } = useDisclosure();
const cancelRef = useRef<HTMLButtonElement | null>(null);
const isTouched = useAppSelector((state) => state.workflow.isTouched);
const handleConfirmClear = useCallback(() => {
const handleNewWorkflow = useCallback(() => {
dispatch(nodeEditorReset());
dispatch(
addToast(
makeToast({
title: t('workflows.workflowEditorReset'),
title: t('workflows.newWorkflowCreated'),
status: 'success',
})
)
);
onClose();
}, [dispatch, t, onClose]);
}, [dispatch, onClose, t]);
const onClick = useCallback(() => {
if (!isTouched) {
handleNewWorkflow();
return;
}
onOpen();
}, [handleNewWorkflow, isTouched, onOpen]);
return (
<>
<MenuItem
as="button"
icon={<FaTrash />}
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
onClick={onOpen}
>
{t('nodes.resetWorkflow')}
<MenuItem as="button" icon={<FaCircleNodes />} onClick={onClick}>
{t('nodes.newWorkflow')}
</MenuItem>
<AlertDialog
@ -61,13 +65,13 @@ const ResetWorkflowEditorMenuItem = () => {
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
{t('nodes.resetWorkflow')}
{t('nodes.newWorkflow')}
</AlertDialogHeader>
<AlertDialogBody py={4}>
<Flex flexDir="column" gap={2}>
<Text>{t('nodes.resetWorkflowDesc')}</Text>
<Text variant="subtext">{t('nodes.resetWorkflowDesc2')}</Text>
<Text>{t('nodes.newWorkflowDesc')}</Text>
<Text variant="subtext">{t('nodes.newWorkflowDesc2')}</Text>
</Flex>
</AlertDialogBody>
@ -75,7 +79,7 @@ const ResetWorkflowEditorMenuItem = () => {
<Button ref={cancelRef} onClick={onClose}>
{t('common.cancel')}
</Button>
<Button colorScheme="error" ml={3} onClick={handleConfirmClear}>
<Button colorScheme="error" ml={3} onClick={handleNewWorkflow}>
{t('common.accept')}
</Button>
</AlertDialogFooter>
@ -85,4 +89,4 @@ const ResetWorkflowEditorMenuItem = () => {
);
};
export default memo(ResetWorkflowEditorMenuItem);
export default memo(NewWorkflowMenuItem);

View File

@ -9,7 +9,7 @@ import IAIIconButton from 'common/components/IAIIconButton';
import { useGlobalMenuCloseTrigger } from 'common/hooks/useGlobalMenuCloseTrigger';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import DownloadWorkflowMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/DownloadWorkflowMenuItem';
import ResetWorkflowEditorMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/ResetWorkflowEditorMenuItem';
import NewWorkflowMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/NewWorkflowMenuItem';
import SaveWorkflowAsMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/SaveWorkflowAsMenuItem';
import SaveWorkflowMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/SaveWorkflowMenuItem';
import SettingsMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/SettingsMenuItem';
@ -39,7 +39,7 @@ const WorkflowLibraryMenu = () => {
{isWorkflowLibraryEnabled && <SaveWorkflowAsMenuItem />}
<DownloadWorkflowMenuItem />
<UploadWorkflowMenuItem />
<ResetWorkflowEditorMenuItem />
<NewWorkflowMenuItem />
<MenuDivider />
<SettingsMenuItem />
</MenuList>

View File

@ -101,6 +101,8 @@ plugins:
extra_javascript:
- https://unpkg.com/tablesort@5.3.0/dist/tablesort.min.js
- javascripts/tablesort.js
- https://widget.kapa.ai/kapa-widget.bundle.js
- javascript/init_kapa_widget.js
extra:
analytics:

View File

@ -32,7 +32,7 @@ classifiers = [
'Topic :: Scientific/Engineering :: Image Processing',
]
dependencies = [
"accelerate~=0.24.0",
"accelerate~=0.25.0",
"albumentations",
"basicsr",
"click",
@ -41,15 +41,15 @@ dependencies = [
"controlnet-aux>=0.0.6",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets",
"diffusers[torch]~=0.23.0",
"diffusers[torch]~=0.24.0",
"dnspython~=2.4.0",
"dynamicprompts",
"easing-functions",
"einops",
"facexlib",
"fastapi~=0.104.1",
"fastapi~=0.105.0",
"fastapi-events~=0.9.1",
"huggingface-hub~=0.16.4",
"huggingface-hub~=0.19.4",
"imohash",
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions
@ -80,11 +80,11 @@ dependencies = [
"semver~=3.0.1",
"send2trash",
"test-tube~=0.7.5",
"torch==2.1.0",
"torchvision==0.16.0",
"torch==2.1.1",
"torchvision==0.16.1",
"torchmetrics~=0.11.0",
"torchsde~=0.2.5",
"transformers~=4.35.0",
"transformers~=4.36.0",
"uvicorn[standard]~=0.21.1",
"windows-curses; sys_platform=='win32'",
]
@ -107,7 +107,7 @@ dependencies = [
"pytest-datadir",
]
"xformers" = [
"xformers==0.0.22post7; sys_platform!='darwin'",
"xformers==0.0.23; sys_platform!='darwin'",
"triton; sys_platform=='linux'",
]
"onnx" = ["onnxruntime"]

View File

@ -28,8 +28,8 @@ from invokeai.app.services.shared.graph import (
IterateInvocation,
LibraryGraph,
)
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
from .test_invoker import create_edge
@ -49,7 +49,8 @@ def simple_graph():
@pytest.fixture
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
db = SqliteDatabase(configuration, InvokeAILogger.get_logger())
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
return InvocationServices(

View File

@ -4,6 +4,7 @@ import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
@ -24,7 +25,6 @@ from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@pytest.fixture
@ -52,8 +52,9 @@ def graph_with_subgraph():
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger())
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")

View File

@ -15,8 +15,11 @@ class TestModel(BaseModel):
@pytest.fixture
def db() -> SqliteItemStorage[TestModel]:
sqlite_db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger())
sqlite_item_storage = SqliteItemStorage[TestModel](db=sqlite_db, table_name="test", id_field="id")
config = InvokeAIAppConfig(use_memory_db=True)
logger = InvokeAILogger.get_logger()
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
sqlite_item_storage = SqliteItemStorage[TestModel](db=db, table_name="test", id_field="id")
return sqlite_item_storage

View File

@ -18,9 +18,9 @@ from invokeai.app.services.model_install import (
ModelInstallServiceBase,
)
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@pytest.fixture
@ -37,9 +37,12 @@ def app_config(datadir: Path) -> InvokeAIAppConfig:
@pytest.fixture
def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
database = SqliteDatabase(app_config, InvokeAILogger.get_logger(config=app_config))
store: ModelRecordServiceBase = ModelRecordServiceSQL(database)
def store(
app_config: InvokeAIAppConfig,
) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=app_config)
db = create_mock_sqlite_database(app_config, logger)
store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
return store

View File

@ -3,6 +3,7 @@ Test the refactored model config classes.
"""
from hashlib import sha256
from typing import Any
import pytest
@ -13,7 +14,6 @@ from invokeai.app.services.model_records import (
ModelRecordServiceSQL,
UnknownModelException,
)
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import (
BaseModelType,
MainCheckpointConfig,
@ -23,13 +23,16 @@ from invokeai.backend.model_manager.config import (
VaeDiffusersConfig,
)
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@pytest.fixture
def store(datadir) -> ModelRecordServiceBase:
def store(
datadir: Any,
) -> ModelRecordServiceBase:
config = InvokeAIAppConfig(root=datadir)
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config, logger)
db = create_mock_sqlite_database(config, logger)
return ModelRecordServiceSQL(db)

0
tests/fixtures/__init__.py vendored Normal file
View File

13
tests/fixtures/sqlite_database.py vendored Normal file
View File

@ -0,0 +1,13 @@
from logging import Logger
from unittest import mock
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
def create_mock_sqlite_database(config: InvokeAIAppConfig, logger: Logger) -> SqliteDatabase:
image_files = mock.Mock(spec=ImageFileStorageBase)
db = init_db(config=config, logger=logger, image_files=image_files)
return db

View File

@ -0,0 +1,272 @@
import sqlite3
from contextlib import closing
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
import pytest
from pydantic import ValidationError
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import (
MigrateCallback,
Migration,
MigrationError,
MigrationSet,
MigrationVersionError,
)
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import (
SqliteMigrator,
)
@pytest.fixture
def logger() -> Logger:
return Logger("test_sqlite_migrator")
@pytest.fixture
def memory_db_conn() -> sqlite3.Connection:
return sqlite3.connect(":memory:")
@pytest.fixture
def memory_db_cursor(memory_db_conn: sqlite3.Connection) -> sqlite3.Cursor:
return memory_db_conn.cursor()
@pytest.fixture
def migrator(logger: Logger) -> SqliteMigrator:
db = SqliteDatabase(db_path=None, logger=logger, verbose=False)
return SqliteMigrator(db=db)
@pytest.fixture
def no_op_migrate_callback() -> MigrateCallback:
def no_op_migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
pass
return no_op_migrate
@pytest.fixture
def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration:
return Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
@pytest.fixture
def migrate_callback_create_table_of_name() -> MigrateCallback:
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
table_name = kwargs["table_name"]
cursor.execute(f"CREATE TABLE {table_name} (id INTEGER PRIMARY KEY);")
return migrate
@pytest.fixture
def migrate_callback_create_test_table() -> MigrateCallback:
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
return migrate
@pytest.fixture
def migration_create_test_table(migrate_callback_create_test_table: MigrateCallback) -> Migration:
return Migration(from_version=0, to_version=1, callback=migrate_callback_create_test_table)
@pytest.fixture
def failing_migration() -> Migration:
def failing_migration(cursor: sqlite3.Cursor, **kwargs) -> None:
raise Exception("Bad migration")
return Migration(from_version=0, to_version=1, callback=failing_migration)
@pytest.fixture
def failing_migrate_callback() -> MigrateCallback:
def failing_migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
raise Exception("Bad migration")
return failing_migrate
def create_migrate(i: int) -> MigrateCallback:
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
return migrate
def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None:
with pytest.raises(ValidationError, match="to_version must be one greater than from_version"):
Migration(from_version=0, to_version=2, callback=no_op_migrate_callback)
# not raising is sufficient
Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)
def test_migration_hash(no_op_migrate_callback: MigrateCallback) -> None:
migration = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
assert hash(migration) == hash((0, 1))
def test_migration_set_add_migration(migrator: SqliteMigrator, migration_no_op: Migration) -> None:
migration = migration_no_op
migrator._migration_set.register(migration)
assert migration in migrator._migration_set._migrations
def test_migration_set_may_not_register_dupes(
migrator: SqliteMigrator, no_op_migrate_callback: MigrateCallback
) -> None:
migrate_0_to_1_a = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
migrate_0_to_1_b = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
migrator._migration_set.register(migrate_0_to_1_a)
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
migrator._migration_set.register(migrate_0_to_1_b)
migrate_1_to_2_a = Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)
migrate_1_to_2_b = Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)
migrator._migration_set.register(migrate_1_to_2_a)
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
migrator._migration_set.register(migrate_1_to_2_b)
def test_migration_set_gets_migration(migration_no_op: Migration) -> None:
migration_set = MigrationSet()
migration_set.register(migration_no_op)
assert migration_set.get(0) == migration_no_op
assert migration_set.get(1) is None
def test_migration_set_validates_migration_chain(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet()
migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"):
# no migration from 0 to 1
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=2, to_version=3, callback=no_op_migrate_callback))
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=4, to_version=5, callback=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"):
# no migration from 3 to 4
migration_set.validate_migration_chain()
def test_migration_set_counts_migrations(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet()
assert migration_set.count == 0
migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
assert migration_set.count == 1
migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
assert migration_set.count == 2
def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet()
assert migration_set.latest_version == 0
migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
assert migration_set.latest_version == 2
migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
assert migration_set.latest_version == 2
def test_migration_runs(memory_db_cursor: sqlite3.Cursor, migrate_callback_create_test_table: MigrateCallback) -> None:
migration = Migration(
from_version=0,
to_version=1,
callback=migrate_callback_create_test_table,
)
migration.callback(memory_db_cursor)
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert memory_db_cursor.fetchone() is not None
def test_migrator_registers_migration(migrator: SqliteMigrator, migration_no_op: Migration) -> None:
migration = migration_no_op
migrator.register_migration(migration)
assert migration in migrator._migration_set._migrations
def test_migrator_creates_migrations_table(migrator: SqliteMigrator) -> None:
cursor = migrator._db.conn.cursor()
migrator._create_migrations_table(cursor)
cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='migrations';")
assert cursor.fetchone() is not None
def test_migrator_migration_sets_version(migrator: SqliteMigrator, migration_no_op: Migration) -> None:
cursor = migrator._db.conn.cursor()
migrator._create_migrations_table(cursor)
migrator.register_migration(migration_no_op)
migrator.run_migrations()
cursor.execute("SELECT MAX(version) FROM migrations;")
assert cursor.fetchone()[0] == 1
def test_migrator_gets_current_version(migrator: SqliteMigrator, migration_no_op: Migration) -> None:
cursor = migrator._db.conn.cursor()
assert migrator._get_current_version(cursor) == 0
migrator._create_migrations_table(cursor)
assert migrator._get_current_version(cursor) == 0
migrator.register_migration(migration_no_op)
migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1
def test_migrator_runs_single_migration(migrator: SqliteMigrator, migration_create_test_table: Migration) -> None:
cursor = migrator._db.conn.cursor()
migrator._create_migrations_table(cursor)
migrator._run_migration(migration_create_test_table)
assert migrator._get_current_version(cursor) == 1
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert cursor.fetchone() is not None
def test_migrator_runs_all_migrations_in_memory(migrator: SqliteMigrator) -> None:
cursor = migrator._db.conn.cursor()
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
assert migrator._get_current_version(cursor) == 3
def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
migrator = SqliteMigrator(db=db)
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
with closing(sqlite3.connect(original_db_path)) as original_db_conn:
original_db_cursor = original_db_conn.cursor()
assert SqliteMigrator._get_current_version(original_db_cursor) == 3
# Must manually close else we get an error on Windows
db.conn.close()
def test_migrator_makes_no_changes_on_failed_migration(
migrator: SqliteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
) -> None:
cursor = migrator._db.conn.cursor()
migrator.register_migration(migration_no_op)
migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1
migrator.register_migration(Migration(from_version=1, to_version=2, callback=failing_migrate_callback))
with pytest.raises(MigrationError, match="Bad migration"):
migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1
def test_idempotent_migrations(migrator: SqliteMigrator, migration_create_test_table: Migration) -> None:
cursor = migrator._db.conn.cursor()
migrator.register_migration(migration_create_test_table)
migrator.run_migrations()
# not throwing is sufficient
migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1