Merge branch 'main' into sdxl-convert-safetensors

This commit is contained in:
Brandon 2024-02-02 10:10:49 -05:00 committed by GitHub
commit 72db2ee352
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 437 additions and 384 deletions

View File

@ -169,7 +169,7 @@ the command `npm install -g pnpm` if needed)
_For Linux with an AMD GPU:_
```sh
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.6
```
_For non-GPU systems:_

View File

@ -477,7 +477,7 @@ Then type the following commands:
=== "AMD System"
```bash
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.6
```
### Corrupted configuration file

View File

@ -154,7 +154,7 @@ manager, please follow these steps:
=== "ROCm (AMD)"
```bash
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.6
```
=== "CPU (Intel Macs & non-GPU systems)"
@ -313,7 +313,7 @@ code for InvokeAI. For this to work, you will need to install the
on your system, please see the [Git Installation
Guide](https://github.com/git-guides/install-git)
You will also need to install the [frontend development toolchain](https://github.com/invoke-ai/InvokeAI/blob/main/docs/contributing/contribution_guides/contributingToFrontend.md).
You will also need to install the [frontend development toolchain](https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/README.md).
If you have a "normal" installation, you should create a totally separate virtual environment for the git-based installation, else the two may interfere.
@ -345,7 +345,7 @@ installation protocol (important!)
=== "ROCm (AMD)"
```bash
pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.6
```
=== "CPU (Intel Macs & non-GPU systems)"
@ -361,7 +361,7 @@ installation protocol (important!)
Be sure to pass `-e` (for an editable install) and don't forget the
dot ("."). It is part of the command.
5. Install the [frontend toolchain](https://github.com/invoke-ai/InvokeAI/blob/main/docs/contributing/contribution_guides/contributingToFrontend.md) and do a production build of the UI as described.
5. Install the [frontend toolchain](https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/README.md) and do a production build of the UI as described.
6. You can now run `invokeai` and its related commands. The code will be
read from the repository, so that you can edit the .py source files

View File

@ -134,7 +134,7 @@ recipes are available
When installing torch and torchvision manually with `pip`, remember to provide
the argument `--extra-index-url
https://download.pytorch.org/whl/rocm5.4.2` as described in the [Manual
https://download.pytorch.org/whl/rocm5.6` as described in the [Manual
Installation Guide](020_INSTALL_MANUAL.md).
This will be done automatically for you if you use the installer

View File

@ -25,6 +25,7 @@ To use a community workflow, download the the `.json` node graph file and load i
+ [GPT2RandomPromptMaker](#gpt2randompromptmaker)
+ [Grid to Gif](#grid-to-gif)
+ [Halftone](#halftone)
+ [Hand Refiner with MeshGraphormer](#hand-refiner-with-meshgraphormer)
+ [Image and Mask Composition Pack](#image-and-mask-composition-pack)
+ [Image Dominant Color](#image-dominant-color)
+ [Image to Character Art Image Nodes](#image-to-character-art-image-nodes)
@ -196,6 +197,18 @@ CMYK Halftone Output:
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/c59c578f-db8e-4d66-8c66-2851752d75ea" width="300" />
--------------------------------
### Hand Refiner with MeshGraphormer
**Description**: Hand Refiner takes in your image and automatically generates a fixed depth map for the hands along with a mask of the hands region that will conveniently allow you to use them along with ControlNet to fix the wonky hands generated by Stable Diffusion
**Node Link:** https://github.com/blessedcoolant/invoke_meshgraphormer
**View**
<img src="https://raw.githubusercontent.com/blessedcoolant/invoke_meshgraphormer/main/assets/preview.jpg" />
--------------------------------
### Image and Mask Composition Pack
**Description:** This is a pack of nodes for composing masks and images, including a simple text mask creator and both image and latent offset nodes. The offsets wrap around, so these can be used in conjunction with the Seamless node to progressively generate centered on different parts of the seamless tiling.

View File

@ -455,7 +455,7 @@ def get_torch_source() -> (Union[str, None], str):
optional_modules = "[onnx]"
if OS == "Linux":
if device == "rocm":
url = "https://download.pytorch.org/whl/rocm5.4.2"
url = "https://download.pytorch.org/whl/rocm5.6"
elif device == "cpu":
url = "https://download.pytorch.org/whl/cpu"

View File

@ -2,6 +2,7 @@
from logging import Logger
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger
@ -22,7 +23,6 @@ from ..services.invocation_queue.invocation_queue_memory import MemoryInvocation
from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
from ..services.model_install import ModelInstallService
@ -80,7 +80,7 @@ class ApiDependencies:
board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService()
events = FastAPIEventService(event_handler_id)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)

View File

@ -274,7 +274,7 @@ class InvokeAIAppConfig(InvokeAISettings):
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", json_schema_extra=Categories.Generation)
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', json_schema_extra=Categories.Generation)
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Generation)
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
png_compress_level : int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
# QUEUE
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", json_schema_extra=Categories.Queue)

View File

@ -1,10 +1,8 @@
from abc import ABC, abstractmethod
from typing import Callable, Generic, Optional, TypeVar
from typing import Callable, Generic, TypeVar
from pydantic import BaseModel
from invokeai.app.services.shared.pagination import PaginatedResults
T = TypeVar("T", bound=BaseModel)
@ -25,23 +23,14 @@ class ItemStorageABC(ABC, Generic[T]):
"""Gets the item, parsing it into a Pydantic model"""
pass
@abstractmethod
def get_raw(self, item_id: str) -> Optional[str]:
"""Gets the raw item as a string, skipping Pydantic parsing"""
pass
@abstractmethod
def set(self, item: T) -> None:
"""Sets the item"""
pass
@abstractmethod
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
"""Gets a paginated list of items"""
pass
@abstractmethod
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
def delete(self, item_id: str) -> None:
"""Deletes the item"""
pass
def on_changed(self, on_changed: Callable[[T], None]) -> None:

View File

@ -0,0 +1,50 @@
from collections import OrderedDict
from contextlib import suppress
from typing import Generic, Optional, TypeVar
from pydantic import BaseModel
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
T = TypeVar("T", bound=BaseModel)
class ItemStorageMemory(ItemStorageABC, Generic[T]):
"""
Provides a simple in-memory storage for items, with a maximum number of items to store.
The storage uses the LRU strategy to evict items from storage when the max has been reached.
"""
def __init__(self, id_field: str = "id", max_items: int = 10) -> None:
super().__init__()
if max_items < 1:
raise ValueError("max_items must be at least 1")
if not id_field:
raise ValueError("id_field must not be empty")
self._id_field = id_field
self._items: OrderedDict[str, T] = OrderedDict()
self._max_items = max_items
def get(self, item_id: str) -> Optional[T]:
# If the item exists, move it to the end of the OrderedDict.
item = self._items.pop(item_id, None)
if item is not None:
self._items[item_id] = item
return self._items.get(item_id)
def set(self, item: T) -> None:
item_id = getattr(item, self._id_field)
if item_id in self._items:
# If item already exists, remove it and add it to the end
self._items.pop(item_id)
elif len(self._items) >= self._max_items:
# If cache is full, evict the least recently used item
self._items.popitem(last=False)
self._items[item_id] = item
self._on_changed(item)
def delete(self, item_id: str) -> None:
# This is a no-op if the item doesn't exist.
with suppress(KeyError):
del self._items[item_id]
self._on_deleted(item_id)

View File

@ -1,147 +0,0 @@
import sqlite3
import threading
from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, TypeAdapter
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .item_storage_base import ItemStorageABC
T = TypeVar("T", bound=BaseModel)
class SqliteItemStorage(ItemStorageABC, Generic[T]):
_table_name: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_id_field: str
_lock: threading.RLock
_validator: Optional[TypeAdapter[T]]
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._cursor = self._conn.cursor()
self._validator: Optional[TypeAdapter[T]] = None
self._create_table()
def _create_table(self):
try:
self._lock.acquire()
self._cursor.execute(
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
item TEXT,
id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
)
self._cursor.execute(
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
)
finally:
self._lock.release()
def _parse_item(self, item: str) -> T:
if self._validator is None:
"""
We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
we can create it when it is first needed instead.
__orig_class__ is technically an implementation detail of the typing module, not a supported API
"""
self._validator = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
return self._validator.validate_json(item)
def set(self, item: T):
try:
self._lock.acquire()
self._cursor.execute(
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.model_dump_json(warnings=False, exclude_none=True),),
)
self._conn.commit()
finally:
self._lock.release()
self._on_changed(item)
def get(self, id: str) -> Optional[T]:
try:
self._lock.acquire()
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return self._parse_item(result[0])
def get_raw(self, id: str) -> Optional[str]:
try:
self._lock.acquire()
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return result[0]
def delete(self, id: str):
try:
self._lock.acquire()
self._cursor.execute(f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),))
self._conn.commit()
finally:
self._lock.release()
self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
(per_page, page * per_page),
)
result = self._cursor.fetchall()
items = [self._parse_item(r[0]) for r in result]
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
(f"%{query}%", per_page, page * per_page),
)
result = self._cursor.fetchall()
items = [self._parse_item(r[0]) for r in result]
self._cursor.execute(
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
(f"%{query}%",),
)
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)

View File

@ -7,6 +7,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -31,6 +32,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
migrator.register_migration(build_migration_3(app_config=config, logger=logger))
migrator.register_migration(build_migration_4())
migrator.register_migration(build_migration_5())
migrator.run_migrations()
return db

View File

@ -0,0 +1,34 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration5Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._drop_graph_executions(cursor)
def _drop_graph_executions(self, cursor: sqlite3.Cursor) -> None:
"""Drops the `graph_executions` table."""
cursor.execute(
"""--sql
DROP TABLE IF EXISTS graph_executions;
"""
)
def build_migration_5() -> Migration:
"""
Build the migration from database version 4 to 5.
Introduced in v3.6.3, this migration:
- Drops the `graph_executions` table. We are able to do this because we are moving the graph storage
to be purely in-memory.
"""
migration_5 = Migration(
from_version=4,
to_version=5,
callback=Migration5Callback(),
)
return migration_5

View File

@ -12,7 +12,7 @@ import psutil
import torch
from compel.cross_attention_control import Arguments
from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, SlicedAttnProcessor
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from torch import nn
import invokeai.backend.util.logging as logger

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import diffusers
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlnetMixin
from diffusers.loaders import FromOriginalControlNetMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from diffusers.models.embeddings import (
@ -14,8 +14,13 @@ from diffusers.models.embeddings import (
Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, get_down_block
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from torch import nn
from invokeai.backend.util.logging import InvokeAILogger
@ -27,7 +32,7 @@ from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger(__name__)
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
"""
A ControlNet model.

View File

@ -5,14 +5,14 @@ pip install <path_to_git_source>.
import os
import platform
from distutils.version import LooseVersion
from importlib.metadata import PackageNotFoundError, distribution, distributions
import pkg_resources
import psutil
import requests
from rich import box, print
from rich.console import Console, group
from rich.panel import Panel
from rich.prompt import Prompt
from rich.prompt import Confirm, Prompt
from rich.style import Style
from invokeai.version import __version__
@ -61,6 +61,65 @@ def get_pypi_versions():
return latest_version, latest_release_candidate, versions
def get_torch_extra_index_url() -> str | None:
"""
Determine torch wheel source URL and optional modules based on the user's OS.
"""
resolved_url = None
# In all other cases (like MacOS (MPS) or Linux+CUDA), there is no need to specify the extra index URL.
torch_package_urls = {
"windows_cuda": "https://download.pytorch.org/whl/cu121",
"linux_rocm": "https://download.pytorch.org/whl/rocm5.6",
"linux_cpu": "https://download.pytorch.org/whl/cpu",
}
nvidia_packages_present = (
len([d.metadata["Name"] for d in distributions() if d.metadata["Name"].startswith("nvidia")]) > 0
)
device = "cuda" if nvidia_packages_present else None
manual_gpu_selection_prompt = (
"[bold]We tried and failed to guess your GPU capabilities[/] :thinking_face:. Please select the GPU type:"
)
if OS == "Linux":
if not device:
# do we even need to offer a CPU-only install option?
print(manual_gpu_selection_prompt)
print("1: NVIDIA (CUDA)")
print("2: AMD (ROCm)")
print("3: No GPU - CPU only")
answer = Prompt.ask("Choice:", choices=["1", "2", "3"], default="1")
match answer:
case "1":
device = "cuda"
case "2":
device = "rocm"
case "3":
device = "cpu"
if device != "cuda":
resolved_url = torch_package_urls[f"linux_{device}"]
if OS == "Windows":
if not device:
print(manual_gpu_selection_prompt)
print("1: NVIDIA (CUDA)")
print("2: No GPU - CPU only")
answer = Prompt.ask("Your choice:", choices=["1", "2"], default="1")
match answer:
case "1":
device = "cuda"
case "2":
device = "cpu"
if device == "cuda":
resolved_url = torch_package_urls[f"windows_{device}"]
return resolved_url
def welcome(latest_release: str, latest_prerelease: str):
@group()
def text():
@ -89,12 +148,11 @@ def welcome(latest_release: str, latest_prerelease: str):
def get_extras():
extras = ""
try:
_ = pkg_resources.get_distribution("xformers")
distribution("xformers")
extras = "[xformers]"
except pkg_resources.DistributionNotFound:
pass
except PackageNotFoundError:
extras = ""
return extras
@ -125,8 +183,22 @@ def main():
extras = get_extras()
console.line()
force_reinstall = Confirm.ask(
"[bold]Force reinstallation of all dependencies?[/] This [i]may[/] help fix a broken upgrade, but is usually not necessary.",
default=False,
)
console.line()
flags = []
if (index_url := get_torch_extra_index_url()) is not None:
flags.append(f"--extra-index-url {index_url}")
if force_reinstall:
flags.append("--force-reinstall")
flags = " ".join(flags)
print(f":crossed_fingers: Upgrading to [yellow]{release}[/yellow]")
cmd = f'pip install "invokeai{extras}=={release}" --use-pep517 --upgrade'
cmd = f'pip install "invokeai{extras}=={release}" --use-pep517 --upgrade {flags}'
print("")
print("")

View File

@ -1,9 +1,26 @@
module.exports = {
extends: ['@invoke-ai/eslint-config-react'],
plugins: ['path', 'i18next'],
rules: {
// TODO(psyche): Enable this rule. Requires no default exports in components - many changes.
'react-refresh/only-export-components': 'off',
// TODO(psyche): Enable this rule. Requires a lot of eslint-disable-next-line comments.
'@typescript-eslint/consistent-type-assertions': 'off',
// https://github.com/qdanik/eslint-plugin-path
'path/no-relative-imports': ['error', { maxDepth: 0 }],
// https://github.com/edvardchen/eslint-plugin-i18next/blob/HEAD/docs/rules/no-literal-string.md
'i18next/no-literal-string': 'error',
},
overrides: [
/**
* Overrides for stories
*/
{
files: ['*.stories.tsx'],
rules: {
// We may not have i18n available in stories.
'i18next/no-literal-string': 'off',
},
},
],
};

View File

@ -111,7 +111,7 @@
},
"devDependencies": {
"@arthurgeron/eslint-plugin-react-usememo": "^2.2.3",
"@invoke-ai/eslint-config-react": "^0.0.12",
"@invoke-ai/eslint-config-react": "^0.0.13",
"@invoke-ai/prettier-config-react": "^0.0.6",
"@storybook/addon-docs": "^7.6.10",
"@storybook/addon-essentials": "^7.6.10",

View File

@ -178,8 +178,8 @@ devDependencies:
specifier: ^2.2.3
version: 2.2.3
'@invoke-ai/eslint-config-react':
specifier: ^0.0.12
version: 0.0.12(@typescript-eslint/eslint-plugin@6.19.0)(@typescript-eslint/parser@6.19.0)(eslint-config-prettier@9.1.0)(eslint-plugin-i18next@6.0.3)(eslint-plugin-import@2.29.1)(eslint-plugin-react-hooks@4.6.0)(eslint-plugin-react-refresh@0.4.5)(eslint-plugin-react@7.33.2)(eslint-plugin-simple-import-sort@10.0.0)(eslint-plugin-storybook@0.6.15)(eslint-plugin-unused-imports@3.0.0)(eslint@8.56.0)
specifier: ^0.0.13
version: 0.0.13(@typescript-eslint/eslint-plugin@6.19.0)(@typescript-eslint/parser@6.19.0)(eslint-config-prettier@9.1.0)(eslint-plugin-import@2.29.1)(eslint-plugin-react-hooks@4.6.0)(eslint-plugin-react-refresh@0.4.5)(eslint-plugin-react@7.33.2)(eslint-plugin-simple-import-sort@10.0.0)(eslint-plugin-storybook@0.6.15)(eslint-plugin-unused-imports@3.0.0)(eslint@8.56.0)
'@invoke-ai/prettier-config-react':
specifier: ^0.0.6
version: 0.0.6(prettier@3.2.4)
@ -3551,14 +3551,13 @@ packages:
'@swc/helpers': 0.5.3
dev: false
/@invoke-ai/eslint-config-react@0.0.12(@typescript-eslint/eslint-plugin@6.19.0)(@typescript-eslint/parser@6.19.0)(eslint-config-prettier@9.1.0)(eslint-plugin-i18next@6.0.3)(eslint-plugin-import@2.29.1)(eslint-plugin-react-hooks@4.6.0)(eslint-plugin-react-refresh@0.4.5)(eslint-plugin-react@7.33.2)(eslint-plugin-simple-import-sort@10.0.0)(eslint-plugin-storybook@0.6.15)(eslint-plugin-unused-imports@3.0.0)(eslint@8.56.0):
resolution: {integrity: sha512-6IXENcSa7vv+YPO/TYmC8qXXJFQt3JqDY+Yc1AMf4/d3b3o+CA7/mqepXIhydG9Gqo5jTRknXdDmjSaLxgCJ/g==}
/@invoke-ai/eslint-config-react@0.0.13(@typescript-eslint/eslint-plugin@6.19.0)(@typescript-eslint/parser@6.19.0)(eslint-config-prettier@9.1.0)(eslint-plugin-import@2.29.1)(eslint-plugin-react-hooks@4.6.0)(eslint-plugin-react-refresh@0.4.5)(eslint-plugin-react@7.33.2)(eslint-plugin-simple-import-sort@10.0.0)(eslint-plugin-storybook@0.6.15)(eslint-plugin-unused-imports@3.0.0)(eslint@8.56.0):
resolution: {integrity: sha512-dfo9k+wPHdvpy1z6ABoYXR/Ttzs1FAnbC46ttIxVhZuqDq8K5cLWznivrOfl7f0hJb8Cb8HiuQb4pHDxhHBDqA==}
peerDependencies:
'@typescript-eslint/eslint-plugin': ^6.19.0
'@typescript-eslint/parser': ^6.19.0
eslint: ^8.56.0
eslint-config-prettier: ^9.1.0
eslint-plugin-i18next: ^6.0.3
eslint-plugin-import: ^2.29.1
eslint-plugin-react: ^7.33.2
eslint-plugin-react-hooks: ^4.6.0
@ -3571,7 +3570,6 @@ packages:
'@typescript-eslint/parser': 6.19.0(eslint@8.56.0)(typescript@5.3.3)
eslint: 8.56.0
eslint-config-prettier: 9.1.0(eslint@8.56.0)
eslint-plugin-i18next: 6.0.3
eslint-plugin-import: 2.29.1(@typescript-eslint/parser@6.19.0)(eslint@8.56.0)
eslint-plugin-react: 7.33.2(eslint@8.56.0)
eslint-plugin-react-hooks: 4.6.0(eslint@8.56.0)

View File

@ -98,7 +98,7 @@
"outputs": "Ausgabe",
"data": "Daten",
"safetensors": "Safetensors",
"outpaint": "outpaint",
"outpaint": "Ausmalen",
"details": "Details",
"format": "Format",
"unknown": "Unbekannt",
@ -131,7 +131,8 @@
"localSystem": "Lokales System",
"orderBy": "Ordnen nach",
"saveAs": "Speicher als",
"updated": "Aktualisiert"
"updated": "Aktualisiert",
"copy": "Kopieren"
},
"gallery": {
"generations": "Erzeugungen",
@ -161,7 +162,13 @@
"currentlyInUse": "Dieses Bild wird derzeit in den folgenden Funktionen verwendet:",
"deleteImagePermanent": "Gelöschte Bilder können nicht wiederhergestellt werden.",
"autoAssignBoardOnClick": "Board per Klick automatisch zuweisen",
"noImageSelected": "Kein Bild ausgewählt"
"noImageSelected": "Kein Bild ausgewählt",
"problemDeletingImagesDesc": "Eins oder mehr Bilder könnten nicht gelöscht werden",
"starImage": "Bild markieren",
"assets": "Ressourcen",
"unstarImage": "Markierung Entfernen",
"image": "Bild",
"deleteSelection": "Lösche markierte"
},
"hotkeys": {
"keyboardShortcuts": "Tastenkürzel",
@ -365,7 +372,13 @@
"addNodes": {
"title": "Knotenpunkt hinzufügen",
"desc": "Öffnet das Menü zum Hinzufügen von Knoten"
}
},
"cancelAndClear": {
"title": "Abbruch und leeren"
},
"noHotkeysFound": "Kein Hotkey gefunden",
"searchHotkeys": "Hotkeys durchsuchen",
"clearSearch": "Suche leeren"
},
"modelManager": {
"modelAdded": "Model hinzugefügt",
@ -832,7 +845,13 @@
"hedDescription": "Ganzheitlich verschachtelte Kantenerkennung",
"scribble": "Scribble",
"maxFaces": "Maximal Anzahl Gesichter",
"resizeSimple": "Größe ändern (einfach)"
"resizeSimple": "Größe ändern (einfach)",
"large": "Groß",
"modelSize": "Modell Größe",
"small": "Klein",
"base": "Basis",
"depthAnything": "Depth Anything",
"depthAnythingDescription": "Erstellung einer Tiefenkarte mit der Depth Anything-Technik"
},
"queue": {
"status": "Status",
@ -865,7 +884,7 @@
"item": "Auftrag",
"notReady": "Warteschlange noch nicht bereit",
"batchValues": "Stapel Werte",
"queueCountPrediction": "{{predicted}} zur Warteschlange hinzufügen",
"queueCountPrediction": "{{promptsCount}} Prompts × {{iterations}} Iterationen -> {{count}} Generationen",
"queuedCount": "{{pending}} wartenden Elemente",
"clearQueueAlertDialog": "Die Warteschlange leeren, stoppt den aktuellen Prozess und leert die Warteschlange komplett.",
"completedIn": "Fertig in",
@ -887,7 +906,9 @@
"back": "Hinten",
"resumeSucceeded": "Prozessor wieder aufgenommen",
"resumeTooltip": "Prozessor wieder aufnehmen",
"time": "Zeit"
"time": "Zeit",
"batchQueuedDesc_one": "{{count}} Eintrage ans {{direction}} der Wartschlange hinzugefügt",
"batchQueuedDesc_other": "{{count}} Einträge ans {{direction}} der Wartschlange hinzugefügt"
},
"metadata": {
"negativePrompt": "Negativ Beschreibung",
@ -956,7 +977,8 @@
"enable": "Aktivieren",
"clear": "Leeren",
"maxCacheSize": "Maximale Cache Größe",
"cacheSize": "Cache Größe"
"cacheSize": "Cache Größe",
"useCache": "Benutze Cache"
},
"embedding": {
"noMatchingEmbedding": "Keine passenden Embeddings",
@ -1042,7 +1064,8 @@
},
"compositing": {
"coherenceTab": "Kohärenzpass",
"infillTab": "Füllung"
"infillTab": "Füllung",
"title": "Compositing"
}
}
}

View File

@ -1376,6 +1376,7 @@
"problemCopyingCanvasDesc": "Unable to export base layer",
"problemCopyingImage": "Unable to Copy Image",
"problemCopyingImageLink": "Unable to Copy Image Link",
"problemDownloadingImage": "Unable to Download Image",
"problemDownloadingCanvas": "Problem Downloading Canvas",
"problemDownloadingCanvasDesc": "Unable to export base layer",
"problemImportingMask": "Problem Importing Mask",

View File

@ -0,0 +1,43 @@
import { useAppToaster } from 'app/components/Toaster';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useImageUrlToBlob } from './useImageUrlToBlob';
export const useDownloadImage = () => {
const toaster = useAppToaster();
const { t } = useTranslation();
const imageUrlToBlob = useImageUrlToBlob();
const downloadImage = useCallback(
async (image_url: string, image_name: string) => {
try {
const blob = await imageUrlToBlob(image_url);
if (!blob) {
throw new Error('Unable to create Blob');
}
const url = window.URL.createObjectURL(blob);
const a = document.createElement('a');
a.style.display = 'none';
a.href = url;
a.download = image_name;
document.body.appendChild(a);
a.click();
window.URL.revokeObjectURL(url);
} catch (err) {
toaster({
title: t('toast.problemDownloadingImage'),
description: String(err),
status: 'error',
duration: 2500,
isClosable: true,
});
}
},
[t, toaster, imageUrlToBlob]
);
return { downloadImage };
};

View File

@ -4,6 +4,7 @@ import { useAppToaster } from 'app/components/Toaster';
import { $customStarUI } from 'app/store/nanostores/customStarUI';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCopyImageToClipboard } from 'common/hooks/useCopyImageToClipboard';
import { useDownloadImage } from 'common/hooks/useDownloadImage';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { imagesToChangeSelected, isModalOpenChanged } from 'features/changeBoardModal/store/slice';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
@ -47,7 +48,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const toaster = useAppToaster();
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const customStarUi = useStore($customStarUI);
const { downloadImage } = useDownloadImage();
const { metadata, isLoading: isLoadingMetadata } = useDebouncedMetadata(imageDTO?.image_name);
const { getAndLoadEmbeddedWorkflow, getAndLoadEmbeddedWorkflowResult } = useGetAndLoadEmbeddedWorkflow({});
@ -143,6 +144,10 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
}
}, [unstarImages, imageDTO]);
const handleDownloadImage = useCallback(() => {
downloadImage(imageDTO.image_url, imageDTO.image_name);
}, [downloadImage, imageDTO.image_name, imageDTO.image_url]);
return (
<>
<MenuItem as="a" href={imageDTO.image_url} target="_blank" icon={<PiShareFatBold />}>
@ -153,14 +158,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
{t('parameters.copyImage')}
</MenuItem>
)}
<MenuItem
as="a"
download={true}
href={imageDTO.image_url}
target="_blank"
icon={<PiDownloadSimpleBold />}
w="100%"
>
<MenuItem icon={<PiDownloadSimpleBold />} onClickCapture={handleDownloadImage}>
{t('parameters.downloadImage')}
</MenuItem>
<MenuDivider />

View File

@ -1,13 +1,12 @@
import { ConfirmationAlertDialog, Flex, IconButton, Text, useDisclosure } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleFill } from 'react-icons/pi';
import { addToast } from '../../../../../system/store/systemSlice';
import { makeToast } from '../../../../../system/util/makeToast';
import { nodeEditorReset } from '../../../../store/nodesSlice';
const ClearFlowButton = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();

View File

@ -1,13 +1,12 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { $builtWorkflow } from 'features/nodes/hooks/useWorkflowWatcher';
import { useSaveWorkflowAsDialog } from 'features/workflowLibrary/components/SaveWorkflowAsDialog/useSaveWorkflowAsDialog';
import { isWorkflowWithID, useSaveLibraryWorkflow } from 'features/workflowLibrary/hooks/useSaveWorkflow';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFloppyDiskBold } from 'react-icons/pi';
import { isWorkflowWithID, useSaveLibraryWorkflow } from '../../../../../workflowLibrary/hooks/useSaveWorkflow';
import { $builtWorkflow } from '../../../../hooks/useWorkflowWatcher';
const SaveWorkflowButton = () => {
const { t } = useTranslation();
const isTouched = useAppSelector((s) => s.workflow.isTouched);

View File

@ -26,7 +26,7 @@ import { ImageSizeLinear } from './ImageSizeLinear';
const selector = createMemoizedSelector(
[selectGenerationSlice, selectCanvasSlice, selectHrfSlice, activeTabNameSelector],
(generation, canvas, hrf, activeTabName) => {
const { shouldRandomizeSeed } = generation;
const { shouldRandomizeSeed, model } = generation;
const { hrfEnabled } = hrf;
const badges: string[] = [];
@ -56,7 +56,7 @@ const selector = createMemoizedSelector(
if (hrfEnabled) {
badges.push('HiRes Fix');
}
return { badges, activeTabName };
return { badges, activeTabName, isSDXL: model?.base_model === 'sdxl' };
}
);
@ -66,7 +66,7 @@ const scalingLabelProps: FormLabelProps = {
export const ImageSettingsAccordion = memo(() => {
const { t } = useTranslation();
const { badges, activeTabName } = useAppSelector(selector);
const { badges, activeTabName, isSDXL } = useAppSelector(selector);
const { isOpen: isOpenAccordion, onToggle: onToggleAccordion } = useStandaloneAccordionToggle({
id: 'image-settings',
defaultIsOpen: true,
@ -94,7 +94,7 @@ export const ImageSettingsAccordion = memo(() => {
</Flex>
{(activeTabName === 'img2img' || activeTabName === 'unifiedCanvas') && <ImageToImageStrength />}
{activeTabName === 'img2img' && <ImageToImageFit />}
{activeTabName === 'txt2img' && <HrfSettings />}
{activeTabName === 'txt2img' && !isSDXL && <HrfSettings />}
{activeTabName === 'unifiedCanvas' && (
<>
<ParamScaleBeforeProcessing />

View File

@ -13,14 +13,13 @@ import {
Input,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $workflowCategories } from 'app/store/nanostores/workflowCategories';
import { useSaveWorkflowAsDialog } from 'features/workflowLibrary/components/SaveWorkflowAsDialog/useSaveWorkflowAsDialog';
import { useSaveWorkflowAs } from 'features/workflowLibrary/hooks/useSaveWorkflowAs';
import { t } from 'i18next';
import type { ChangeEvent } from 'react';
import { useCallback, useRef } from 'react';
import { $workflowCategories } from '../../../../app/store/nanostores/workflowCategories';
import { useSaveWorkflowAs } from '../../hooks/useSaveWorkflowAs';
export const SaveWorkflowAsDialog = () => {
const { isOpen, onClose, workflowName, setWorkflowName, shouldSaveToProject, setShouldSaveToProject } =
useSaveWorkflowAsDialog();

View File

@ -1,13 +1,12 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { $builtWorkflow } from 'features/nodes/hooks/useWorkflowWatcher';
import { useSaveWorkflowAsDialog } from 'features/workflowLibrary/components/SaveWorkflowAsDialog/useSaveWorkflowAsDialog';
import { isWorkflowWithID, useSaveLibraryWorkflow } from 'features/workflowLibrary/hooks/useSaveWorkflow';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFloppyDiskBold } from 'react-icons/pi';
import { useAppSelector } from '../../../../app/store/storeHooks';
import { $builtWorkflow } from '../../../nodes/hooks/useWorkflowWatcher';
const SaveWorkflowMenuItem = () => {
const { t } = useTranslation();
const { saveWorkflow } = useSaveLibraryWorkflow();

View File

@ -8,12 +8,11 @@ import {
workflowNameChanged,
workflowSaved,
} from 'features/nodes/store/workflowSlice';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import { useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { useCreateWorkflowMutation, workflowsApi } from 'services/api/endpoints/workflows';
import type { WorkflowCategory } from '../../nodes/types/workflow';
type SaveWorkflowAsArg = {
name: string;
category: WorkflowCategory;

View File

@ -107,7 +107,6 @@ plugins:
extra_javascript:
- https://unpkg.com/tablesort@5.3.0/dist/tablesort.min.js
- javascripts/tablesort.js
- https://widget.kapa.ai/kapa-widget.bundle.js
- javascript/init_kapa_widget.js
extra:

View File

@ -38,7 +38,7 @@ dependencies = [
"clip_anytorch==2.5.2", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==2.0.2",
"controlnet-aux==0.0.7",
"diffusers[torch]==0.25.1",
"diffusers[torch]==0.26.0",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
"numpy==1.26.3", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()

View File

@ -2,6 +2,8 @@ import logging
import pytest
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
PromptCollectionTestInvocation,
@ -19,7 +21,6 @@ from invokeai.app.services.invocation_processor.invocation_processor_default imp
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import (
CollectInvocation,
@ -27,8 +28,6 @@ from invokeai.app.services.shared.graph import (
GraphExecutionState,
IterateInvocation,
)
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
from .test_invoker import create_edge
@ -48,10 +47,8 @@ def simple_graph():
@pytest.fixture
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore

View File

@ -3,8 +3,7 @@ import logging
import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
@ -22,7 +21,6 @@ from invokeai.app.services.invocation_queue.invocation_queue_memory import Memor
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
@ -53,11 +51,6 @@ def graph_with_subgraph():
@pytest.fixture
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore
@ -65,7 +58,7 @@ def mock_services() -> InvocationServices:
boards=None, # type: ignore
configuration=configuration,
events=TestEventService(),
graph_execution_manager=graph_execution_manager,
graph_execution_manager=ItemStorageMemory[GraphExecutionState](),
image_files=None, # type: ignore
image_records=None, # type: ignore
images=None, # type: ignore

View File

@ -1,139 +0,0 @@
import pytest
from pydantic import BaseModel, Field
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
class TestModel(BaseModel):
id: str = Field(description="ID")
name: str = Field(description="Name")
__test__ = False # not a pytest test case
@pytest.fixture
def db() -> SqliteItemStorage[TestModel]:
config = InvokeAIAppConfig(use_memory_db=True)
logger = InvokeAILogger.get_logger()
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
sqlite_item_storage = SqliteItemStorage[TestModel](db=db, table_name="test", id_field="id")
return sqlite_item_storage
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
assert db.get("1") == TestModel(id="1", name="Test")
def test_sqlite_service_can_list(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test"))
results = db.list()
assert results.page == 0
assert results.pages == 1
assert results.per_page == 10
assert results.total == 3
assert results.items == [
TestModel(id="1", name="Test"),
TestModel(id="2", name="Test"),
TestModel(id="3", name="Test"),
]
def test_sqlite_service_can_delete(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.delete("1")
assert db.get("1") is None
def test_sqlite_service_calls_set_callback(db: SqliteItemStorage[TestModel]):
called = False
def on_changed(item: TestModel):
nonlocal called
called = True
db.on_changed(on_changed)
db.set(TestModel(id="1", name="Test"))
assert called
def test_sqlite_service_calls_delete_callback(db: SqliteItemStorage[TestModel]):
called = False
def on_deleted(item_id: str):
nonlocal called
called = True
db.on_deleted(on_deleted)
db.set(TestModel(id="1", name="Test"))
db.delete("1")
assert called
def test_sqlite_service_can_list_with_pagination(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test"))
results = db.list(page=0, per_page=2)
assert results.page == 0
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
def test_sqlite_service_can_list_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test"))
results = db.list(page=1, per_page=2)
assert results.page == 1
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id="3", name="Test")]
def test_sqlite_service_can_search(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test"))
results = db.search(query="Test")
assert results.page == 0
assert results.pages == 1
assert results.per_page == 10
assert results.total == 3
assert results.items == [
TestModel(id="1", name="Test"),
TestModel(id="2", name="Test"),
TestModel(id="3", name="Test"),
]
def test_sqlite_service_can_search_with_pagination(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test"))
results = db.search(query="Test", page=0, per_page=2)
assert results.page == 0
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
def test_sqlite_service_can_search_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test"))
results = db.search(query="Test", page=1, per_page=2)
assert results.page == 1
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id="3", name="Test")]

View File

@ -0,0 +1,110 @@
import re
import pytest
from pydantic import BaseModel
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
class MockItemModel(BaseModel):
id: str
value: int
@pytest.fixture
def item_storage_memory():
return ItemStorageMemory[MockItemModel]()
def test_item_storage_memory_initializes():
item_storage_memory = ItemStorageMemory()
assert item_storage_memory._items == {}
assert item_storage_memory._id_field == "id"
assert item_storage_memory._max_items == 10
item_storage_memory = ItemStorageMemory(id_field="bananas", max_items=20)
assert item_storage_memory._id_field == "bananas"
assert item_storage_memory._max_items == 20
with pytest.raises(ValueError, match=re.escape("max_items must be at least 1")):
item_storage_memory = ItemStorageMemory(max_items=0)
with pytest.raises(ValueError, match=re.escape("id_field must not be empty")):
item_storage_memory = ItemStorageMemory(id_field="")
def test_item_storage_memory_sets(item_storage_memory: ItemStorageMemory[MockItemModel]):
item_1 = MockItemModel(id="1", value=1)
item_storage_memory.set(item_1)
assert item_storage_memory._items == {"1": item_1}
item_2 = MockItemModel(id="2", value=2)
item_storage_memory.set(item_2)
assert item_storage_memory._items == {"1": item_1, "2": item_2}
# Updating value of existing item
item_2_updated = MockItemModel(id="2", value=9001)
item_storage_memory.set(item_2_updated)
assert item_storage_memory._items == {"1": item_1, "2": item_2_updated}
def test_item_storage_memory_gets(item_storage_memory: ItemStorageMemory[MockItemModel]):
item_1 = MockItemModel(id="1", value=1)
item_storage_memory.set(item_1)
item = item_storage_memory.get("1")
assert item == item_1
item_2 = MockItemModel(id="2", value=2)
item_storage_memory.set(item_2)
item = item_storage_memory.get("2")
assert item == item_2
item = item_storage_memory.get("3")
assert item is None
def test_item_storage_memory_deletes(item_storage_memory: ItemStorageMemory[MockItemModel]):
item_1 = MockItemModel(id="1", value=1)
item_2 = MockItemModel(id="2", value=2)
item_storage_memory.set(item_1)
item_storage_memory.set(item_2)
item_storage_memory.delete("2")
assert item_storage_memory._items == {"1": item_1}
def test_item_storage_memory_respects_max():
item_storage_memory = ItemStorageMemory(max_items=3)
for i in range(10):
item_storage_memory.set(MockItemModel(id=str(i), value=i))
assert item_storage_memory._items == {
"7": MockItemModel(id="7", value=7),
"8": MockItemModel(id="8", value=8),
"9": MockItemModel(id="9", value=9),
}
def test_item_storage_memory_calls_set_callback(item_storage_memory: ItemStorageMemory[MockItemModel]):
called_item = None
item = MockItemModel(id="1", value=1)
def on_changed(item: MockItemModel):
nonlocal called_item
called_item = item
item_storage_memory.on_changed(on_changed)
item_storage_memory.set(item)
assert called_item == item
def test_item_storage_memory_calls_delete_callback(item_storage_memory: ItemStorageMemory[MockItemModel]):
called_item_id = None
item = MockItemModel(id="1", value=1)
def on_deleted(item_id: str):
nonlocal called_item_id
called_item_id = item_id
item_storage_memory.on_deleted(on_deleted)
item_storage_memory.set(item)
item_storage_memory.delete("1")
assert called_item_id == "1"