mirror of
https://github.com/invoke-ai/InvokeAI
synced 2025-07-26 05:17:55 +00:00
feat: server-side client state persistence (#8314)
## Summary Move client state persistence from browser to server. - Add new client state persistence service to handle reading and writing client state to db & associated router. The API mirrors that of LocalStorage/IndexedDB where the set/get methods both operate on _keys_. For example, when we persist the canvas state, we send only the new canvas state to the backend - not the whole app state. - The data is very flexibly-typed as a pydantic `JsonValue`. The client is expected to handle all data parsing/validation (it must do this anyways, and does this today). - Change persistence from debounced to throttled at 2 seconds. Maybe less is OK? Trying to not hammer the server. - Add new persistence storage driver in client and use it in redux-remember. It does its best to avoid extraneous persist requests, caching the last data it persisted and noop-ing if there are no changes. - Storage driver tracks pending persist actions using ref counts (bc each slice is persisted independently). If there user navigates away from the page during a persist request, it will give them the "you may lose something if you navigate away" alert. - This "lose something" alert message is not customizable (browser security reasons). - The alert is triggered only when the user closes the tape while a persist network request is mid-flight. It's possible that the user makes a change and closes the page before we start persisting. In this case, they will lose the last 2 seconds of data. - I tried making triggering the alert when a persist was waiting to start, and it felt off. - Maybe the alert isn't even necessary. Again you'd lose 2s of data at most, probably a non issue. IMO after trying it, a subtle indicator somewhere on the page is probably less confusing/intrusive. - Fix an issue where the `redux-remember` enhancer was added _last_ in the enhancer chain, which prevented us detecting when a persist has succeeded. This required a small change to the `unserialze` utility (used during rehydration) to ensure slices enhanced with `redux-undo` are set up correctly as they are rehydrated. - Restructure the redux store code to avoid circular dependencies. I couldn't figure out how to do this without just smooshing it all into the main `store.ts` file. Oh well. Implications: - Because client state is now on the server, different browsers will have the same studio state. For example, if I start working on something in Firefox, if I switch to Chrome, I have the same client state. - Incognito windows won't do anything bc client state is server-side. - It takes a bit longer for persistence to happen thanks to the debounce, but there's now an indicator that tells you your stuff isn't saved yet. - Resetting the browser won't fix an issue with your studio state. You must use `Reset Web UI` to fix it (or otherwise hit the appropriate endpoint). It may be possible to end up in a Catch-22 where you can't click the button and get stuck w/ a borked studio - I think to think through this a bit more, might not be an issue. - It probably takes a bit longer to start up, since we need to retrieve client state over network instead of directly with browser APIs. Other notes: - We could explore adding an "incognito" mode, enabled via `invokeai.yaml` setting or maybe in the UI. This would temporarily disable persistence. Actually, I don't think this really makes sense, bc all the images would be saved to disk. - The studio state is stored in a single row in the DB. Currently, a static row ID is used to force the studio state to be a singleton. It is _possible_ to support multiple saved states. Might be a solve for app workspaces. ## Related Issues / Discussions n/a ## QA Instructions Try it out. It's pretty straightforward. Error states are the main things to test - for example, network blips. The new server-side persistence driver is the only real functional change - everything else is just kinda shuffling things around to support it. ## Merge Plan n/a ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [ ] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_ - [ ] _Updated `What's New` copy (if doing a release after this PR)_
This commit is contained in:
@ -10,6 +10,7 @@ from invokeai.app.services.board_images.board_images_default import BoardImagesS
|
||||
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.boards.boards_default import BoardService
|
||||
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.download.download_default import DownloadQueueService
|
||||
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
|
||||
@ -151,6 +152,7 @@ class ApiDependencies:
|
||||
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
|
||||
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
|
||||
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
|
||||
client_state_persistence = ClientStatePersistenceSqlite(db=db)
|
||||
|
||||
services = InvocationServices(
|
||||
board_image_records=board_image_records,
|
||||
@ -181,6 +183,7 @@ class ApiDependencies:
|
||||
style_preset_records=style_preset_records,
|
||||
style_preset_image_files=style_preset_image_files,
|
||||
workflow_thumbnails=workflow_thumbnails,
|
||||
client_state_persistence=client_state_persistence,
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
@ -5,9 +5,9 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fastapi import Body
|
||||
from fastapi import Body, HTTPException, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, JsonValue
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
@ -173,3 +173,50 @@ async def disable_invocation_cache() -> None:
|
||||
async def get_invocation_cache_status() -> InvocationCacheStatus:
|
||||
"""Clears the invocation cache"""
|
||||
return ApiDependencies.invoker.services.invocation_cache.get_status()
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/client_state",
|
||||
operation_id="get_client_state_by_key",
|
||||
response_model=JsonValue | None,
|
||||
)
|
||||
async def get_client_state_by_key(
|
||||
key: str = Query(..., description="Key to get"),
|
||||
) -> JsonValue | None:
|
||||
"""Gets the client state"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(key)
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error setting client state")
|
||||
|
||||
|
||||
@app_router.post(
|
||||
"/client_state",
|
||||
operation_id="set_client_state",
|
||||
response_model=None,
|
||||
)
|
||||
async def set_client_state(
|
||||
key: str = Query(..., description="Key to set"),
|
||||
value: JsonValue = Body(..., description="Value of the key"),
|
||||
) -> None:
|
||||
"""Sets the client state"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value)
|
||||
except Exception as e:
|
||||
logging.error(f"Error setting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error setting client state")
|
||||
|
||||
|
||||
@app_router.delete(
|
||||
"/client_state",
|
||||
operation_id="delete_client_state",
|
||||
responses={204: {"description": "Client state deleted"}},
|
||||
)
|
||||
async def delete_client_state() -> None:
|
||||
"""Deletes the client state"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.client_state_persistence.delete()
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error deleting client state")
|
||||
|
@ -0,0 +1,35 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
|
||||
class ClientStatePersistenceABC(ABC):
|
||||
"""
|
||||
Base class for client persistence implementations.
|
||||
This class defines the interface for persisting client data.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_by_key(self, key: str, value: JsonValue) -> None:
|
||||
"""
|
||||
Store the data for the client.
|
||||
|
||||
:param data: The client data to be stored.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_by_key(self, key: str) -> JsonValue | None:
|
||||
"""
|
||||
Get the data for the client.
|
||||
|
||||
:return: The client data.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self) -> None:
|
||||
"""
|
||||
Delete the data for the client.
|
||||
"""
|
||||
pass
|
@ -0,0 +1,65 @@
|
||||
import json
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
|
||||
"""
|
||||
Base class for client persistence implementations.
|
||||
This class defines the interface for persisting client data.
|
||||
"""
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._default_row_id = 1
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def set_by_key(self, key: str, value: JsonValue) -> None:
|
||||
state = self.get() or {}
|
||||
state.update({key: value})
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
INSERT INTO client_state (id, data)
|
||||
VALUES ({self._default_row_id}, ?)
|
||||
ON CONFLICT(id) DO UPDATE
|
||||
SET data = excluded.data;
|
||||
""",
|
||||
(json.dumps(state),),
|
||||
)
|
||||
|
||||
def get(self) -> dict[str, JsonValue] | None:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT data FROM client_state
|
||||
WHERE id = {self._default_row_id}
|
||||
"""
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return json.loads(row[0])
|
||||
|
||||
def get_by_key(self, key: str) -> JsonValue | None:
|
||||
state = self.get()
|
||||
if state is None:
|
||||
return None
|
||||
return state.get(key, None)
|
||||
|
||||
def delete(self) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
DELETE FROM client_state
|
||||
WHERE id = {self._default_row_id}
|
||||
"""
|
||||
)
|
@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
|
||||
from invokeai.app.services.boards.boards_base import BoardServiceABC
|
||||
from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase
|
||||
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
@ -73,6 +74,7 @@ class InvocationServices:
|
||||
style_preset_records: "StylePresetRecordsStorageBase",
|
||||
style_preset_image_files: "StylePresetImageFileStorageBase",
|
||||
workflow_thumbnails: "WorkflowThumbnailServiceBase",
|
||||
client_state_persistence: "ClientStatePersistenceABC",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
@ -102,3 +104,4 @@ class InvocationServices:
|
||||
self.style_preset_records = style_preset_records
|
||||
self.style_preset_image_files = style_preset_image_files
|
||||
self.workflow_thumbnails = workflow_thumbnails
|
||||
self.client_state_persistence = client_state_persistence
|
||||
|
@ -23,6 +23,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import build_migration_21
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@ -63,6 +64,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_18())
|
||||
migrator.register_migration(build_migration_19(app_config=config))
|
||||
migrator.register_migration(build_migration_20())
|
||||
migrator.register_migration(build_migration_21())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
@ -0,0 +1,40 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration21Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE client_state (
|
||||
id INTEGER PRIMARY KEY CHECK(id = 1),
|
||||
data TEXT NOT NULL, -- Frontend will handle the shape of this data
|
||||
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP)
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TRIGGER tg_client_state_updated_at
|
||||
AFTER UPDATE ON client_state
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE client_state
|
||||
SET updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = OLD.id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def build_migration_21() -> Migration:
|
||||
"""Builds the migration object for migrating from version 20 to version 21. This includes:
|
||||
- Creating the `client_state` table.
|
||||
- Adding a trigger to update the `updated_at` field on updates.
|
||||
"""
|
||||
return Migration(
|
||||
from_version=20,
|
||||
to_version=21,
|
||||
callback=Migration21Callback(),
|
||||
)
|
3
invokeai/frontend/web/.gitignore
vendored
3
invokeai/frontend/web/.gitignore
vendored
@ -44,4 +44,5 @@ yalc.lock
|
||||
|
||||
# vitest
|
||||
tsconfig.vitest-temp.json
|
||||
coverage/
|
||||
coverage/
|
||||
*.tgz
|
||||
|
@ -26,7 +26,7 @@ i18n.use(initReactI18next).init({
|
||||
returnNull: false,
|
||||
});
|
||||
|
||||
const store = createStore(undefined, false);
|
||||
const store = createStore({ driver: { getItem: () => {}, setItem: () => {} }, persistThrottle: 2000 });
|
||||
$store.set(store);
|
||||
$baseUrl.set('http://localhost:9090');
|
||||
|
||||
|
@ -197,6 +197,10 @@ export default [
|
||||
importNames: ['isEqual'],
|
||||
message: 'Please use objectEquals from @observ33r/object-equals instead.',
|
||||
},
|
||||
{
|
||||
name: 'zod/v3',
|
||||
message: 'Import from zod instead.',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
|
@ -63,7 +63,6 @@
|
||||
"framer-motion": "^11.10.0",
|
||||
"i18next": "^25.3.2",
|
||||
"i18next-http-backend": "^3.0.2",
|
||||
"idb-keyval": "6.2.2",
|
||||
"jsondiffpatch": "^0.7.3",
|
||||
"konva": "^9.3.22",
|
||||
"linkify-react": "^4.3.1",
|
||||
@ -103,7 +102,7 @@
|
||||
"use-debounce": "^10.0.5",
|
||||
"use-device-pixel-ratio": "^1.1.2",
|
||||
"uuid": "^11.1.0",
|
||||
"zod": "^4.0.5",
|
||||
"zod": "^4.0.10",
|
||||
"zod-validation-error": "^3.5.2"
|
||||
},
|
||||
"peerDependencies": {
|
||||
|
32
invokeai/frontend/web/pnpm-lock.yaml
generated
32
invokeai/frontend/web/pnpm-lock.yaml
generated
@ -80,9 +80,6 @@ importers:
|
||||
i18next-http-backend:
|
||||
specifier: ^3.0.2
|
||||
version: 3.0.2
|
||||
idb-keyval:
|
||||
specifier: 6.2.2
|
||||
version: 6.2.2
|
||||
jsondiffpatch:
|
||||
specifier: ^0.7.3
|
||||
version: 0.7.3
|
||||
@ -201,11 +198,11 @@ importers:
|
||||
specifier: ^11.1.0
|
||||
version: 11.1.0
|
||||
zod:
|
||||
specifier: ^4.0.5
|
||||
version: 4.0.5
|
||||
specifier: ^4.0.10
|
||||
version: 4.0.10
|
||||
zod-validation-error:
|
||||
specifier: ^3.5.2
|
||||
version: 3.5.3(zod@4.0.5)
|
||||
version: 3.5.3(zod@4.0.10)
|
||||
devDependencies:
|
||||
'@eslint/js':
|
||||
specifier: ^9.31.0
|
||||
@ -411,6 +408,10 @@ packages:
|
||||
resolution: {integrity: sha512-vbavdySgbTTrmFE+EsiqUTzlOr5bzlnJtUv9PynGCAKvfQqjIXbvFdumPM/GxMDfyuGMJaJAU6TO4zc1Jf1i8Q==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
|
||||
'@babel/runtime@7.28.2':
|
||||
resolution: {integrity: sha512-KHp2IflsnGywDjBWDkR9iEqiWSpc8GIi0lgTT3mOElT0PP1tG26P4tmFI2YvAdzgq9RGyoHZQEIEdZy6Ec5xCA==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
|
||||
'@babel/template@7.27.2':
|
||||
resolution: {integrity: sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
@ -2771,9 +2772,6 @@ packages:
|
||||
typescript:
|
||||
optional: true
|
||||
|
||||
idb-keyval@6.2.2:
|
||||
resolution: {integrity: sha512-yjD9nARJ/jb1g+CvD0tlhUHOrJ9Sy0P8T9MF3YaLlHnSRpwPfpTX0XIvpmw3gAJUmEu3FiICLBDPXVwyEvrleg==}
|
||||
|
||||
ieee754@1.2.1:
|
||||
resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==}
|
||||
|
||||
@ -4511,8 +4509,8 @@ packages:
|
||||
zod@3.25.76:
|
||||
resolution: {integrity: sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==}
|
||||
|
||||
zod@4.0.5:
|
||||
resolution: {integrity: sha512-/5UuuRPStvHXu7RS+gmvRf4NXrNxpSllGwDnCBcJZtQsKrviYXm54yDGV2KYNLT5kq0lHGcl7lqWJLgSaG+tgA==}
|
||||
zod@4.0.10:
|
||||
resolution: {integrity: sha512-3vB+UU3/VmLL2lvwcY/4RV2i9z/YU0DTV/tDuYjrwmx5WeJ7hwy+rGEEx8glHp6Yxw7ibRbKSaIFBgReRPe5KA==}
|
||||
|
||||
zustand@4.5.7:
|
||||
resolution: {integrity: sha512-CHOUy7mu3lbD6o6LJLfllpjkzhHXSBlX8B9+qPddUsIfeF5S/UZ5q0kmCsnRqT1UHFQZchNFDDzMbQsuesHWlw==}
|
||||
@ -4633,6 +4631,8 @@ snapshots:
|
||||
|
||||
'@babel/runtime@7.27.6': {}
|
||||
|
||||
'@babel/runtime@7.28.2': {}
|
||||
|
||||
'@babel/template@7.27.2':
|
||||
dependencies:
|
||||
'@babel/code-frame': 7.27.1
|
||||
@ -5736,7 +5736,7 @@ snapshots:
|
||||
'@testing-library/dom@10.4.0':
|
||||
dependencies:
|
||||
'@babel/code-frame': 7.27.1
|
||||
'@babel/runtime': 7.27.6
|
||||
'@babel/runtime': 7.28.2
|
||||
'@types/aria-query': 5.0.4
|
||||
aria-query: 5.3.0
|
||||
chalk: 4.1.2
|
||||
@ -7266,8 +7266,6 @@ snapshots:
|
||||
optionalDependencies:
|
||||
typescript: 5.8.3
|
||||
|
||||
idb-keyval@6.2.2: {}
|
||||
|
||||
ieee754@1.2.1: {}
|
||||
|
||||
ignore@5.3.2: {}
|
||||
@ -9062,13 +9060,13 @@ snapshots:
|
||||
dependencies:
|
||||
zod: 3.25.76
|
||||
|
||||
zod-validation-error@3.5.3(zod@4.0.5):
|
||||
zod-validation-error@3.5.3(zod@4.0.10):
|
||||
dependencies:
|
||||
zod: 4.0.5
|
||||
zod: 4.0.10
|
||||
|
||||
zod@3.25.76: {}
|
||||
|
||||
zod@4.0.5: {}
|
||||
zod@4.0.10: {}
|
||||
|
||||
zustand@4.5.7(@types/react@18.3.23)(immer@10.1.1)(react@18.3.1):
|
||||
dependencies:
|
||||
|
@ -2,10 +2,10 @@ import { Box } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
|
||||
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
|
||||
import { useClearStorage } from 'app/contexts/clear-storage-context';
|
||||
import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction';
|
||||
import type { PartialAppConfig } from 'app/types/invokeai';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import { useClearStorage } from 'common/hooks/useClearStorage';
|
||||
import { AppContent } from 'features/ui/components/AppContent';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { ErrorBoundary } from 'react-error-boundary';
|
||||
|
@ -1,10 +1,12 @@
|
||||
import 'i18n';
|
||||
|
||||
import type { Middleware } from '@reduxjs/toolkit';
|
||||
import { ClearStorageProvider } from 'app/contexts/clear-storage-context';
|
||||
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
|
||||
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
|
||||
import type { LoggingOverrides } from 'app/logging/logger';
|
||||
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
|
||||
import { buildStorageApi } from 'app/store/enhancers/reduxRemember/driver';
|
||||
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
|
||||
import { $authToken } from 'app/store/nanostores/authToken';
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
@ -70,6 +72,14 @@ interface Props extends PropsWithChildren {
|
||||
* If provided, overrides in-app navigation to the model manager
|
||||
*/
|
||||
onClickGoToModelManager?: () => void;
|
||||
storageConfig?: {
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
getItem: (key: string) => Promise<any>;
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
setItem: (key: string, value: any) => Promise<any>;
|
||||
clear: () => Promise<void>;
|
||||
persistThrottle: number;
|
||||
};
|
||||
}
|
||||
|
||||
const InvokeAIUI = ({
|
||||
@ -96,6 +106,7 @@ const InvokeAIUI = ({
|
||||
loggingOverrides,
|
||||
onClickGoToModelManager,
|
||||
whatsNew,
|
||||
storageConfig,
|
||||
}: Props) => {
|
||||
useLayoutEffect(() => {
|
||||
/*
|
||||
@ -308,9 +319,21 @@ const InvokeAIUI = ({
|
||||
};
|
||||
}, [isDebugging]);
|
||||
|
||||
const storage = useMemo(() => buildStorageApi(storageConfig), [storageConfig]);
|
||||
|
||||
useEffect(() => {
|
||||
const storageCleanup = storage.registerListeners();
|
||||
return () => {
|
||||
storageCleanup();
|
||||
};
|
||||
}, [storage]);
|
||||
|
||||
const store = useMemo(() => {
|
||||
return createStore(projectId);
|
||||
}, [projectId]);
|
||||
return createStore({
|
||||
driver: storage.reduxRememberDriver,
|
||||
persistThrottle: storageConfig?.persistThrottle ?? 2000,
|
||||
});
|
||||
}, [storage.reduxRememberDriver, storageConfig?.persistThrottle]);
|
||||
|
||||
useEffect(() => {
|
||||
$store.set(store);
|
||||
@ -327,11 +350,13 @@ const InvokeAIUI = ({
|
||||
|
||||
return (
|
||||
<React.StrictMode>
|
||||
<Provider store={store}>
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<App config={config} studioInitAction={studioInitAction} />
|
||||
</React.Suspense>
|
||||
</Provider>
|
||||
<ClearStorageProvider value={storage.clearStorage}>
|
||||
<Provider store={store}>
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<App config={config} studioInitAction={studioInitAction} />
|
||||
</React.Suspense>
|
||||
</Provider>
|
||||
</ClearStorageProvider>
|
||||
</React.StrictMode>
|
||||
);
|
||||
};
|
||||
|
@ -0,0 +1,10 @@
|
||||
import { createContext, useContext } from 'react';
|
||||
|
||||
const ClearStorageContext = createContext<() => void>(() => {});
|
||||
|
||||
export const ClearStorageProvider = ClearStorageContext.Provider;
|
||||
|
||||
export const useClearStorage = () => {
|
||||
const context = useContext(ClearStorageContext);
|
||||
return context;
|
||||
};
|
@ -1,3 +1,2 @@
|
||||
export const STORAGE_PREFIX = '@@invokeai-';
|
||||
export const EMPTY_ARRAY = [];
|
||||
export const EMPTY_OBJECT = {};
|
||||
|
@ -1,40 +1,243 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
|
||||
import { $projectId } from 'app/store/nanostores/projectId';
|
||||
import type { UseStore } from 'idb-keyval';
|
||||
import { clear, createStore as createIDBKeyValStore, get, set } from 'idb-keyval';
|
||||
import { atom } from 'nanostores';
|
||||
import type { Driver } from 'redux-remember';
|
||||
import type { Driver as ReduxRememberDriver } from 'redux-remember';
|
||||
import { getBaseUrl } from 'services/api';
|
||||
import { buildAppInfoUrl } from 'services/api/endpoints/appInfo';
|
||||
|
||||
// Create a custom idb-keyval store (just needed to customize the name)
|
||||
const $idbKeyValStore = atom<UseStore>(createIDBKeyValStore('invoke', 'invoke-store'));
|
||||
const log = logger('system');
|
||||
|
||||
export const clearIdbKeyValStore = () => {
|
||||
clear($idbKeyValStore.get());
|
||||
const buildOSSServerBackedDriver = (): {
|
||||
reduxRememberDriver: ReduxRememberDriver;
|
||||
clearStorage: () => Promise<void>;
|
||||
registerListeners: () => () => void;
|
||||
} => {
|
||||
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
|
||||
// it when a slice is being persisted and decrementing it when the persistence is done.
|
||||
let persistRefCount = 0;
|
||||
|
||||
// Keep track of the last persisted state for each key to avoid unnecessary network requests.
|
||||
//
|
||||
// `redux-remember` persists individual slices of state, so we can implicity denylist a slice by not giving it a
|
||||
// persist config.
|
||||
//
|
||||
// However, we may need to avoid persisting individual _fields_ of a slice. `redux-remember` does not provide a
|
||||
// way to do this directly.
|
||||
//
|
||||
// To accomplish this, we add a layer of logic on top of the `redux-remember`. In the state serializer function
|
||||
// provided to `redux-remember`, we can omit certain fields from the state that we do not want to persist. See
|
||||
// the implementation in `store.ts` for this logic.
|
||||
//
|
||||
// This logic is unknown to `redux-remember`. When an omitted field changes, it will still attempt to persist the
|
||||
// whole slice, even if the final, _serialized_ slice value is unchanged.
|
||||
//
|
||||
// To avoid unnecessary network requests, we keep track of the last persisted state for each key. If the value to
|
||||
// be persisted is the same as the last persisted value, we can skip the network request.
|
||||
const lastPersistedState = new Map<string, unknown>();
|
||||
|
||||
const getUrl = (key?: string) => {
|
||||
const baseUrl = getBaseUrl();
|
||||
const query: Record<string, string> = {};
|
||||
if (key) {
|
||||
query['key'] = key;
|
||||
}
|
||||
const path = buildAppInfoUrl('client_state', query);
|
||||
const url = `${baseUrl}/${path}`;
|
||||
return url;
|
||||
};
|
||||
|
||||
const reduxRememberDriver: ReduxRememberDriver = {
|
||||
getItem: async (key) => {
|
||||
try {
|
||||
const url = getUrl(key);
|
||||
const res = await fetch(url, { method: 'GET' });
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
const text = await res.text();
|
||||
if (!lastPersistedState.get(key)) {
|
||||
lastPersistedState.set(key, text);
|
||||
}
|
||||
return JSON.parse(text);
|
||||
} catch (originalError) {
|
||||
throw new StorageError({
|
||||
key,
|
||||
projectId: $projectId.get(),
|
||||
originalError,
|
||||
});
|
||||
}
|
||||
},
|
||||
setItem: async (key, value) => {
|
||||
try {
|
||||
persistRefCount++;
|
||||
if (lastPersistedState.get(key) === value) {
|
||||
log.trace(`Skipping persist for key "${key}" as value is unchanged.`);
|
||||
return value;
|
||||
}
|
||||
const url = getUrl(key);
|
||||
const headers = new Headers({
|
||||
'Content-Type': 'application/json',
|
||||
});
|
||||
const res = await fetch(url, { method: 'POST', headers, body: value });
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
|
||||
lastPersistedState.set(key, value);
|
||||
return value;
|
||||
} catch (originalError) {
|
||||
throw new StorageError({
|
||||
key,
|
||||
value,
|
||||
projectId: $projectId.get(),
|
||||
originalError,
|
||||
});
|
||||
} finally {
|
||||
persistRefCount--;
|
||||
if (persistRefCount < 0) {
|
||||
log.trace('Persist ref count is negative, resetting to 0');
|
||||
persistRefCount = 0;
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
const clearStorage = async () => {
|
||||
try {
|
||||
persistRefCount++;
|
||||
const url = getUrl();
|
||||
const res = await fetch(url, { method: 'DELETE' });
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
} catch {
|
||||
log.error('Failed to reset client state');
|
||||
} finally {
|
||||
persistRefCount--;
|
||||
lastPersistedState.clear();
|
||||
if (persistRefCount < 0) {
|
||||
log.trace('Persist ref count is negative, resetting to 0');
|
||||
persistRefCount = 0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const registerListeners = () => {
|
||||
const onBeforeUnload = (e: BeforeUnloadEvent) => {
|
||||
if (persistRefCount > 0) {
|
||||
e.preventDefault();
|
||||
}
|
||||
};
|
||||
window.addEventListener('beforeunload', onBeforeUnload);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('beforeunload', onBeforeUnload);
|
||||
};
|
||||
};
|
||||
|
||||
return { reduxRememberDriver, clearStorage, registerListeners };
|
||||
};
|
||||
|
||||
// Create redux-remember driver, wrapping idb-keyval
|
||||
export const idbKeyValDriver: Driver = {
|
||||
getItem: (key) => {
|
||||
const buildCustomDriver = (api: {
|
||||
getItem: (key: string) => Promise<any>;
|
||||
setItem: (key: string, value: any) => Promise<any>;
|
||||
clear: () => Promise<void>;
|
||||
}): {
|
||||
reduxRememberDriver: ReduxRememberDriver;
|
||||
clearStorage: () => Promise<void>;
|
||||
registerListeners: () => () => void;
|
||||
} => {
|
||||
// See the comment in `buildOSSServerBackedDriver` for an explanation of this variable.
|
||||
let persistRefCount = 0;
|
||||
|
||||
// See the comment in `buildOSSServerBackedDriver` for an explanation of this variable.
|
||||
const lastPersistedState = new Map<string, unknown>();
|
||||
|
||||
const reduxRememberDriver: ReduxRememberDriver = {
|
||||
getItem: async (key) => {
|
||||
try {
|
||||
log.trace(`Getting client state for key "${key}"`);
|
||||
return await api.getItem(key);
|
||||
} catch (originalError) {
|
||||
throw new StorageError({
|
||||
key,
|
||||
projectId: $projectId.get(),
|
||||
originalError,
|
||||
});
|
||||
}
|
||||
},
|
||||
setItem: async (key, value) => {
|
||||
try {
|
||||
persistRefCount++;
|
||||
|
||||
if (lastPersistedState.get(key) === value) {
|
||||
log.trace(`Skipping setting client state for key "${key}" as value is unchanged`);
|
||||
return value;
|
||||
}
|
||||
log.trace(`Setting client state for key "${key}", ${value}`);
|
||||
await api.setItem(key, value);
|
||||
lastPersistedState.set(key, value);
|
||||
return value;
|
||||
} catch (originalError) {
|
||||
throw new StorageError({
|
||||
key,
|
||||
value,
|
||||
projectId: $projectId.get(),
|
||||
originalError,
|
||||
});
|
||||
} finally {
|
||||
persistRefCount--;
|
||||
if (persistRefCount < 0) {
|
||||
log.trace('Persist ref count is negative, resetting to 0');
|
||||
persistRefCount = 0;
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
const clearStorage = async () => {
|
||||
try {
|
||||
return get(key, $idbKeyValStore.get());
|
||||
} catch (originalError) {
|
||||
throw new StorageError({
|
||||
key,
|
||||
projectId: $projectId.get(),
|
||||
originalError,
|
||||
});
|
||||
persistRefCount++;
|
||||
log.trace('Clearing client state');
|
||||
await api.clear();
|
||||
} catch {
|
||||
log.error('Failed to clear client state');
|
||||
} finally {
|
||||
persistRefCount--;
|
||||
lastPersistedState.clear();
|
||||
if (persistRefCount < 0) {
|
||||
log.trace('Persist ref count is negative, resetting to 0');
|
||||
persistRefCount = 0;
|
||||
}
|
||||
}
|
||||
},
|
||||
setItem: (key, value) => {
|
||||
try {
|
||||
return set(key, value, $idbKeyValStore.get());
|
||||
} catch (originalError) {
|
||||
throw new StorageError({
|
||||
key,
|
||||
value,
|
||||
projectId: $projectId.get(),
|
||||
originalError,
|
||||
});
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
const registerListeners = () => {
|
||||
const onBeforeUnload = (e: BeforeUnloadEvent) => {
|
||||
if (persistRefCount > 0) {
|
||||
e.preventDefault();
|
||||
}
|
||||
};
|
||||
window.addEventListener('beforeunload', onBeforeUnload);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('beforeunload', onBeforeUnload);
|
||||
};
|
||||
};
|
||||
|
||||
return { reduxRememberDriver, clearStorage, registerListeners };
|
||||
};
|
||||
|
||||
export const buildStorageApi = (api?: {
|
||||
getItem: (key: string) => Promise<any>;
|
||||
setItem: (key: string, value: any) => Promise<any>;
|
||||
clear: () => Promise<void>;
|
||||
}) => {
|
||||
if (api) {
|
||||
return buildCustomDriver(api);
|
||||
} else {
|
||||
return buildOSSServerBackedDriver();
|
||||
}
|
||||
};
|
||||
|
@ -1,73 +0,0 @@
|
||||
import type { TypedStartListening } from '@reduxjs/toolkit';
|
||||
import { addListener, createListenerMiddleware } from '@reduxjs/toolkit';
|
||||
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
|
||||
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
|
||||
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
|
||||
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
|
||||
import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/batchEnqueued';
|
||||
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
|
||||
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
|
||||
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
|
||||
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
|
||||
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
|
||||
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
|
||||
import { addImageUploadedFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageUploaded';
|
||||
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
|
||||
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
|
||||
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
|
||||
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
|
||||
const startAppListening = listenerMiddleware.startListening as AppStartListening;
|
||||
|
||||
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
|
||||
|
||||
/**
|
||||
* The RTK listener middleware is a lightweight alternative sagas/observables.
|
||||
*
|
||||
* Most side effect logic should live in a listener.
|
||||
*/
|
||||
|
||||
// Image uploaded
|
||||
addImageUploadedFulfilledListener(startAppListening);
|
||||
|
||||
// Image deleted
|
||||
addDeleteBoardAndImagesFulfilledListener(startAppListening);
|
||||
|
||||
// User Invoked
|
||||
addAnyEnqueuedListener(startAppListening);
|
||||
addBatchEnqueuedListener(startAppListening);
|
||||
|
||||
// Socket.IO
|
||||
addSocketConnectedEventListener(startAppListening);
|
||||
|
||||
// Gallery bulk download
|
||||
addBulkDownloadListeners(startAppListening);
|
||||
|
||||
// Boards
|
||||
addImageAddedToBoardFulfilledListener(startAppListening);
|
||||
addImageRemovedFromBoardFulfilledListener(startAppListening);
|
||||
addBoardIdSelectedListener(startAppListening);
|
||||
addArchivedOrDeletedBoardListener(startAppListening);
|
||||
|
||||
// Node schemas
|
||||
addGetOpenAPISchemaListener(startAppListening);
|
||||
|
||||
// Models
|
||||
addModelSelectedListener(startAppListening);
|
||||
|
||||
// app startup
|
||||
addAppStartedListener(startAppListening);
|
||||
addModelsLoadedListener(startAppListening);
|
||||
addAppConfigReceivedListener(startAppListening);
|
||||
|
||||
// Ad-hoc upscale workflwo
|
||||
addAdHocPostProcessingRequestedListener(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening);
|
@ -1,6 +1,6 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
autoAddBoardIdChanged,
|
||||
|
@ -1,4 +1,4 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
|
||||
|
||||
export const addAnyEnqueuedListener = (startAppListening: AppStartListening) => {
|
||||
|
@ -1,4 +1,4 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { setInfillMethod } from 'features/controlLayers/store/paramsSlice';
|
||||
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
|
||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { truncate } from 'es-toolkit/compat';
|
||||
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
|
||||
import { toast } from 'features/toast/toast';
|
||||
|
@ -1,4 +1,4 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { getImageUsage } from 'features/deleteImageModal/store/state';
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { selectGetImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { size } from 'es-toolkit/compat';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
const log = logger('gallery');
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
const log = logger('gallery');
|
||||
|
@ -1,7 +1,6 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { AppStartListening, RootState } from 'app/store/store';
|
||||
import { omit } from 'es-toolkit/compat';
|
||||
import { imageUploadedClientSide } from 'features/gallery/store/actions';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import type { AppDispatch, AppStartListening, RootState } from 'app/store/store';
|
||||
import { controlLayerModelChanged, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
|
||||
import {
|
||||
|
@ -1,4 +1,4 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { isNil } from 'es-toolkit';
|
||||
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
|
@ -1,8 +1,8 @@
|
||||
import { objectEquals } from '@observ33r/object-equals';
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import { atom } from 'nanostores';
|
||||
import { api } from 'services/api';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
|
@ -1,35 +1,46 @@
|
||||
import type { ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
|
||||
import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/toolkit';
|
||||
import type { ThunkDispatch, TypedStartListening, UnknownAction } from '@reduxjs/toolkit';
|
||||
import { addListener, combineReducers, configureStore, createListenerMiddleware } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
|
||||
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
|
||||
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
|
||||
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
|
||||
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
|
||||
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
|
||||
import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/batchEnqueued';
|
||||
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
|
||||
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
|
||||
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
|
||||
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
|
||||
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
|
||||
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
|
||||
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
|
||||
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
|
||||
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
|
||||
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { keys, mergeWith, omit, pick } from 'es-toolkit/compat';
|
||||
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
|
||||
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { canvasPersistConfig, canvasSlice, canvasUndoableConfig } from 'features/controlLayers/store/canvasSlice';
|
||||
import {
|
||||
canvasSessionSlice,
|
||||
canvasStagingAreaPersistConfig,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { refImagesPersistConfig, refImagesSlice } from 'features/controlLayers/store/refImagesSlice';
|
||||
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
|
||||
import { workflowLibraryPersistConfig, workflowLibrarySlice } from 'features/nodes/store/workflowLibrarySlice';
|
||||
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
|
||||
import { queueSlice } from 'features/queue/store/queueSlice';
|
||||
import { stylePresetPersistConfig, stylePresetSlice } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import { configSlice } from 'features/system/store/configSlice';
|
||||
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
|
||||
import { uiPersistConfig, uiSlice } from 'features/ui/store/uiSlice';
|
||||
import { changeBoardModalSliceConfig } from 'features/changeBoardModal/store/slice';
|
||||
import { canvasSettingsSliceConfig } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { canvasSliceConfig } from 'features/controlLayers/store/canvasSlice';
|
||||
import { canvasSessionSliceConfig } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { lorasSliceConfig } from 'features/controlLayers/store/lorasSlice';
|
||||
import { paramsSliceConfig } from 'features/controlLayers/store/paramsSlice';
|
||||
import { refImagesSliceConfig } from 'features/controlLayers/store/refImagesSlice';
|
||||
import { dynamicPromptsSliceConfig } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import { gallerySliceConfig } from 'features/gallery/store/gallerySlice';
|
||||
import { modelManagerSliceConfig } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { nodesSliceConfig } from 'features/nodes/store/nodesSlice';
|
||||
import { workflowLibrarySliceConfig } from 'features/nodes/store/workflowLibrarySlice';
|
||||
import { workflowSettingsSliceConfig } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { upscaleSliceConfig } from 'features/parameters/store/upscaleSlice';
|
||||
import { queueSliceConfig } from 'features/queue/store/queueSlice';
|
||||
import { stylePresetSliceConfig } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import { configSliceConfig } from 'features/system/store/configSlice';
|
||||
import { systemSliceConfig } from 'features/system/store/systemSlice';
|
||||
import { uiSliceConfig } from 'features/ui/store/uiSlice';
|
||||
import { diff } from 'jsondiffpatch';
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
|
||||
import type { Driver, SerializeFunction, UnserializeFunction } from 'redux-remember';
|
||||
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
||||
import undoable, { newHistory } from 'redux-undo';
|
||||
import { serializeError } from 'serialize-error';
|
||||
@ -37,123 +48,116 @@ import { api } from 'services/api';
|
||||
import { authToastMiddleware } from 'services/api/authToastMiddleware';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
import { STORAGE_PREFIX } from './constants';
|
||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||
import { addArchivedOrDeletedBoardListener } from './middleware/listenerMiddleware/listeners/addArchivedOrDeletedBoardListener';
|
||||
import { addImageUploadedFulfilledListener } from './middleware/listenerMiddleware/listeners/imageUploaded';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
const allReducers = {
|
||||
[api.reducerPath]: api.reducer,
|
||||
[gallerySlice.name]: gallerySlice.reducer,
|
||||
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
|
||||
[systemSlice.name]: systemSlice.reducer,
|
||||
[configSlice.name]: configSlice.reducer,
|
||||
[uiSlice.name]: uiSlice.reducer,
|
||||
[dynamicPromptsSlice.name]: dynamicPromptsSlice.reducer,
|
||||
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
|
||||
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
|
||||
[queueSlice.name]: queueSlice.reducer,
|
||||
[canvasSlice.name]: undoable(canvasSlice.reducer, canvasUndoableConfig),
|
||||
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
|
||||
[upscaleSlice.name]: upscaleSlice.reducer,
|
||||
[stylePresetSlice.name]: stylePresetSlice.reducer,
|
||||
[paramsSlice.name]: paramsSlice.reducer,
|
||||
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer,
|
||||
[canvasSessionSlice.name]: canvasSessionSlice.reducer,
|
||||
[lorasSlice.name]: lorasSlice.reducer,
|
||||
[workflowLibrarySlice.name]: workflowLibrarySlice.reducer,
|
||||
[refImagesSlice.name]: refImagesSlice.reducer,
|
||||
// When adding a slice, add the config to the SLICE_CONFIGS object below, then add the reducer to ALL_REDUCERS.
|
||||
const SLICE_CONFIGS = {
|
||||
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig,
|
||||
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig,
|
||||
[canvasSliceConfig.slice.reducerPath]: canvasSliceConfig,
|
||||
[changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig,
|
||||
[configSliceConfig.slice.reducerPath]: configSliceConfig,
|
||||
[dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig,
|
||||
[gallerySliceConfig.slice.reducerPath]: gallerySliceConfig,
|
||||
[lorasSliceConfig.slice.reducerPath]: lorasSliceConfig,
|
||||
[modelManagerSliceConfig.slice.reducerPath]: modelManagerSliceConfig,
|
||||
[nodesSliceConfig.slice.reducerPath]: nodesSliceConfig,
|
||||
[paramsSliceConfig.slice.reducerPath]: paramsSliceConfig,
|
||||
[queueSliceConfig.slice.reducerPath]: queueSliceConfig,
|
||||
[refImagesSliceConfig.slice.reducerPath]: refImagesSliceConfig,
|
||||
[stylePresetSliceConfig.slice.reducerPath]: stylePresetSliceConfig,
|
||||
[systemSliceConfig.slice.reducerPath]: systemSliceConfig,
|
||||
[uiSliceConfig.slice.reducerPath]: uiSliceConfig,
|
||||
[upscaleSliceConfig.slice.reducerPath]: upscaleSliceConfig,
|
||||
[workflowLibrarySliceConfig.slice.reducerPath]: workflowLibrarySliceConfig,
|
||||
[workflowSettingsSliceConfig.slice.reducerPath]: workflowSettingsSliceConfig,
|
||||
};
|
||||
|
||||
const rootReducer = combineReducers(allReducers);
|
||||
// TS makes it really hard to dynamically create this object :/ so it's just hardcoded here.
|
||||
// Remember to wrap undoable reducers in `undoable()`!
|
||||
const ALL_REDUCERS = {
|
||||
[api.reducerPath]: api.reducer,
|
||||
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig.slice.reducer,
|
||||
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig.slice.reducer,
|
||||
// Undoable!
|
||||
[canvasSliceConfig.slice.reducerPath]: undoable(
|
||||
canvasSliceConfig.slice.reducer,
|
||||
canvasSliceConfig.undoableConfig?.reduxUndoOptions
|
||||
),
|
||||
[changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig.slice.reducer,
|
||||
[configSliceConfig.slice.reducerPath]: configSliceConfig.slice.reducer,
|
||||
[dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig.slice.reducer,
|
||||
[gallerySliceConfig.slice.reducerPath]: gallerySliceConfig.slice.reducer,
|
||||
[lorasSliceConfig.slice.reducerPath]: lorasSliceConfig.slice.reducer,
|
||||
[modelManagerSliceConfig.slice.reducerPath]: modelManagerSliceConfig.slice.reducer,
|
||||
// Undoable!
|
||||
[nodesSliceConfig.slice.reducerPath]: undoable(
|
||||
nodesSliceConfig.slice.reducer,
|
||||
nodesSliceConfig.undoableConfig?.reduxUndoOptions
|
||||
),
|
||||
[paramsSliceConfig.slice.reducerPath]: paramsSliceConfig.slice.reducer,
|
||||
[queueSliceConfig.slice.reducerPath]: queueSliceConfig.slice.reducer,
|
||||
[refImagesSliceConfig.slice.reducerPath]: refImagesSliceConfig.slice.reducer,
|
||||
[stylePresetSliceConfig.slice.reducerPath]: stylePresetSliceConfig.slice.reducer,
|
||||
[systemSliceConfig.slice.reducerPath]: systemSliceConfig.slice.reducer,
|
||||
[uiSliceConfig.slice.reducerPath]: uiSliceConfig.slice.reducer,
|
||||
[upscaleSliceConfig.slice.reducerPath]: upscaleSliceConfig.slice.reducer,
|
||||
[workflowLibrarySliceConfig.slice.reducerPath]: workflowLibrarySliceConfig.slice.reducer,
|
||||
[workflowSettingsSliceConfig.slice.reducerPath]: workflowSettingsSliceConfig.slice.reducer,
|
||||
};
|
||||
|
||||
const rootReducer = combineReducers(ALL_REDUCERS);
|
||||
|
||||
const rememberedRootReducer = rememberReducer(rootReducer);
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
export type PersistConfig<T = any> = {
|
||||
/**
|
||||
* The name of the slice.
|
||||
*/
|
||||
name: keyof typeof allReducers;
|
||||
/**
|
||||
* The initial state of the slice.
|
||||
*/
|
||||
initialState: T;
|
||||
/**
|
||||
* Migrate the state to the current version during rehydration.
|
||||
* @param state The rehydrated state.
|
||||
* @returns A correctly-shaped state.
|
||||
*/
|
||||
migrate: (state: unknown) => T;
|
||||
/**
|
||||
* Keys to omit from the persisted state.
|
||||
*/
|
||||
persistDenylist: (keyof T)[];
|
||||
};
|
||||
|
||||
const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[galleryPersistConfig.name]: galleryPersistConfig,
|
||||
[nodesPersistConfig.name]: nodesPersistConfig,
|
||||
[systemPersistConfig.name]: systemPersistConfig,
|
||||
[uiPersistConfig.name]: uiPersistConfig,
|
||||
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
|
||||
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
|
||||
[canvasPersistConfig.name]: canvasPersistConfig,
|
||||
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
|
||||
[upscalePersistConfig.name]: upscalePersistConfig,
|
||||
[stylePresetPersistConfig.name]: stylePresetPersistConfig,
|
||||
[paramsPersistConfig.name]: paramsPersistConfig,
|
||||
[canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig,
|
||||
[canvasStagingAreaPersistConfig.name]: canvasStagingAreaPersistConfig,
|
||||
[lorasPersistConfig.name]: lorasPersistConfig,
|
||||
[workflowLibraryPersistConfig.name]: workflowLibraryPersistConfig,
|
||||
[refImagesSlice.name]: refImagesPersistConfig,
|
||||
};
|
||||
|
||||
const unserialize: UnserializeFunction = (data, key) => {
|
||||
const persistConfig = persistConfigs[key as keyof typeof persistConfigs];
|
||||
if (!persistConfig) {
|
||||
const sliceConfig = SLICE_CONFIGS[key as keyof typeof SLICE_CONFIGS];
|
||||
if (!sliceConfig?.persistConfig) {
|
||||
throw new Error(`No persist config for slice "${key}"`);
|
||||
}
|
||||
const { getInitialState, persistConfig, undoableConfig } = sliceConfig;
|
||||
let state;
|
||||
try {
|
||||
const { initialState, migrate } = persistConfig;
|
||||
const parsed = JSON.parse(data);
|
||||
const initialState = getInitialState();
|
||||
|
||||
// strip out old keys
|
||||
const stripped = pick(deepClone(parsed), keys(initialState));
|
||||
// run (additive) migrations
|
||||
const migrated = migrate(stripped);
|
||||
const stripped = pick(deepClone(data), keys(initialState));
|
||||
/*
|
||||
* Merge in initial state as default values, covering any missing keys. You might be tempted to use _.defaultsDeep,
|
||||
* but that merges arrays by index and partial objects by key. Using an identity function as the customizer results
|
||||
* in behaviour like defaultsDeep, but doesn't overwrite any values that are not undefined in the migrated state.
|
||||
*/
|
||||
const transformed = mergeWith(migrated, initialState, (objVal) => objVal);
|
||||
const unPersistDenylisted = mergeWith(stripped, initialState, (objVal) => objVal);
|
||||
// run (additive) migrations
|
||||
const migrated = persistConfig.migrate(unPersistDenylisted);
|
||||
|
||||
log.debug(
|
||||
{
|
||||
persistedData: parsed,
|
||||
rehydratedData: transformed,
|
||||
diff: diff(parsed, transformed) as JsonObject, // this is always serializable
|
||||
persistedData: data as JsonObject,
|
||||
rehydratedData: migrated as JsonObject,
|
||||
diff: diff(data, migrated) as JsonObject,
|
||||
},
|
||||
`Rehydrated slice "${key}"`
|
||||
);
|
||||
state = transformed;
|
||||
state = migrated;
|
||||
} catch (err) {
|
||||
log.warn(
|
||||
{ error: serializeError(err as Error) },
|
||||
`Error rehydrating slice "${key}", falling back to default initial state`
|
||||
);
|
||||
state = persistConfig.initialState;
|
||||
state = getInitialState();
|
||||
}
|
||||
|
||||
// If the slice is undoable, we need to wrap it in a new history - only nodes and canvas are undoable at the moment.
|
||||
// TODO(psyche): make this automatic & remove the hard-coding for specific slices.
|
||||
if (key === nodesSlice.name || key === canvasSlice.name) {
|
||||
// Undoable slices must be wrapped in a history!
|
||||
if (undoableConfig) {
|
||||
return newHistory([], state, []);
|
||||
} else {
|
||||
return state;
|
||||
@ -161,21 +165,30 @@ const unserialize: UnserializeFunction = (data, key) => {
|
||||
};
|
||||
|
||||
const serialize: SerializeFunction = (data, key) => {
|
||||
const persistConfig = persistConfigs[key as keyof typeof persistConfigs];
|
||||
if (!persistConfig) {
|
||||
const sliceConfig = SLICE_CONFIGS[key as keyof typeof SLICE_CONFIGS];
|
||||
if (!sliceConfig?.persistConfig) {
|
||||
throw new Error(`No persist config for slice "${key}"`);
|
||||
}
|
||||
// Heuristic to determine if the slice is undoable - could just hardcode it in the persistConfig
|
||||
const isUndoable = 'present' in data && 'past' in data && 'future' in data && '_latestUnfiltered' in data;
|
||||
const result = omit(isUndoable ? data.present : data, persistConfig.persistDenylist);
|
||||
|
||||
const result = omit(
|
||||
sliceConfig.undoableConfig ? data.present : data,
|
||||
sliceConfig.persistConfig.persistDenylist ?? []
|
||||
);
|
||||
|
||||
return JSON.stringify(result);
|
||||
};
|
||||
|
||||
export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
const PERSISTED_KEYS = Object.values(SLICE_CONFIGS)
|
||||
.filter((sliceConfig) => !!sliceConfig.persistConfig)
|
||||
.map((sliceConfig) => sliceConfig.slice.reducerPath);
|
||||
|
||||
export const createStore = (reduxRememberOptions: { driver: Driver; persistThrottle: number }) =>
|
||||
configureStore({
|
||||
reducer: rememberedRootReducer,
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware({
|
||||
// serializableCheck: false,
|
||||
// immutableCheck: false,
|
||||
serializableCheck: import.meta.env.MODE === 'development',
|
||||
immutableCheck: import.meta.env.MODE === 'development',
|
||||
})
|
||||
@ -185,19 +198,16 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
// .concat(getDebugLoggerMiddleware())
|
||||
.prepend(listenerMiddleware.middleware),
|
||||
enhancers: (getDefaultEnhancers) => {
|
||||
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());
|
||||
if (persist) {
|
||||
_enhancers.push(
|
||||
rememberEnhancer(idbKeyValDriver, keys(persistConfigs), {
|
||||
persistDebounce: 300,
|
||||
serialize,
|
||||
unserialize,
|
||||
prefix: uniqueStoreKey ? `${STORAGE_PREFIX}${uniqueStoreKey}-` : STORAGE_PREFIX,
|
||||
errorHandler,
|
||||
})
|
||||
);
|
||||
}
|
||||
return _enhancers;
|
||||
const enhancers = getDefaultEnhancers();
|
||||
return enhancers.prepend(
|
||||
rememberEnhancer(reduxRememberOptions.driver, PERSISTED_KEYS, {
|
||||
persistThrottle: reduxRememberOptions.persistThrottle,
|
||||
serialize,
|
||||
unserialize,
|
||||
prefix: '',
|
||||
errorHandler,
|
||||
})
|
||||
);
|
||||
},
|
||||
devTools: {
|
||||
actionSanitizer,
|
||||
@ -214,7 +224,48 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
|
||||
export type AppStore = ReturnType<typeof createStore>;
|
||||
export type RootState = ReturnType<AppStore['getState']>;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
export type AppThunkDispatch = ThunkDispatch<RootState, any, UnknownAction>;
|
||||
export type AppDispatch = ReturnType<typeof createStore>['dispatch'];
|
||||
export type AppGetState = ReturnType<typeof createStore>['getState'];
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
|
||||
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
|
||||
|
||||
const startAppListening = listenerMiddleware.startListening as AppStartListening;
|
||||
addImageUploadedFulfilledListener(startAppListening);
|
||||
|
||||
// Image deleted
|
||||
addDeleteBoardAndImagesFulfilledListener(startAppListening);
|
||||
|
||||
// User Invoked
|
||||
addAnyEnqueuedListener(startAppListening);
|
||||
addBatchEnqueuedListener(startAppListening);
|
||||
|
||||
// Socket.IO
|
||||
addSocketConnectedEventListener(startAppListening);
|
||||
|
||||
// Gallery bulk download
|
||||
addBulkDownloadListeners(startAppListening);
|
||||
|
||||
// Boards
|
||||
addImageAddedToBoardFulfilledListener(startAppListening);
|
||||
addImageRemovedFromBoardFulfilledListener(startAppListening);
|
||||
addBoardIdSelectedListener(startAppListening);
|
||||
addArchivedOrDeletedBoardListener(startAppListening);
|
||||
|
||||
// Node schemas
|
||||
addGetOpenAPISchemaListener(startAppListening);
|
||||
|
||||
// Models
|
||||
addModelSelectedListener(startAppListening);
|
||||
|
||||
// app startup
|
||||
addAppStartedListener(startAppListening);
|
||||
addModelsLoadedListener(startAppListening);
|
||||
addAppConfigReceivedListener(startAppListening);
|
||||
|
||||
// Ad-hoc upscale workflwo
|
||||
addAdHocPostProcessingRequestedListener(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening);
|
||||
|
46
invokeai/frontend/web/src/app/store/types.ts
Normal file
46
invokeai/frontend/web/src/app/store/types.ts
Normal file
@ -0,0 +1,46 @@
|
||||
import type { Slice } from '@reduxjs/toolkit';
|
||||
import type { UndoableOptions } from 'redux-undo';
|
||||
import type { ZodType } from 'zod';
|
||||
|
||||
type StateFromSlice<T extends Slice> = T extends Slice<infer U> ? U : never;
|
||||
|
||||
export type SliceConfig<T extends Slice> = {
|
||||
/**
|
||||
* The redux slice (return of createSlice).
|
||||
*/
|
||||
slice: T;
|
||||
/**
|
||||
* The zod schema for the slice.
|
||||
*/
|
||||
schema: ZodType<StateFromSlice<T>>;
|
||||
/**
|
||||
* A function that returns the initial state of the slice.
|
||||
*/
|
||||
getInitialState: () => StateFromSlice<T>;
|
||||
/**
|
||||
* The optional persist configuration for this slice. If omitted, the slice will not be persisted.
|
||||
*/
|
||||
persistConfig?: {
|
||||
/**
|
||||
* Migrate the state to the current version during rehydration. This method should throw an error if the migration
|
||||
* fails.
|
||||
*
|
||||
* @param state The rehydrated state.
|
||||
* @returns A correctly-shaped state.
|
||||
*/
|
||||
migrate: (state: unknown) => StateFromSlice<T>;
|
||||
/**
|
||||
* Keys to omit from the persisted state.
|
||||
*/
|
||||
persistDenylist?: (keyof StateFromSlice<T>)[];
|
||||
};
|
||||
/**
|
||||
* The optional undoable configuration for this slice. If omitted, the slice will not be undoable.
|
||||
*/
|
||||
undoableConfig?: {
|
||||
/**
|
||||
* The options to be passed into redux-undo.
|
||||
*/
|
||||
reduxUndoOptions: UndoableOptions<StateFromSlice<T>>;
|
||||
};
|
||||
};
|
@ -1,130 +1,299 @@
|
||||
import type { FilterType } from 'features/controlLayers/store/filters';
|
||||
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import type { TabName } from 'features/ui/store/uiTypes';
|
||||
import { zFilterType } from 'features/controlLayers/store/filters';
|
||||
import { zParameterPrecision, zParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { zTabName } from 'features/ui/store/uiTypes';
|
||||
import type { PartialDeep } from 'type-fest';
|
||||
import z from 'zod';
|
||||
|
||||
/**
|
||||
* A disable-able application feature
|
||||
*/
|
||||
export type AppFeature =
|
||||
| 'faceRestore'
|
||||
| 'upscaling'
|
||||
| 'lightbox'
|
||||
| 'modelManager'
|
||||
| 'githubLink'
|
||||
| 'discordLink'
|
||||
| 'bugLink'
|
||||
| 'aboutModal'
|
||||
| 'localization'
|
||||
| 'consoleLogging'
|
||||
| 'dynamicPrompting'
|
||||
| 'batches'
|
||||
| 'syncModels'
|
||||
| 'multiselect'
|
||||
| 'pauseQueue'
|
||||
| 'resumeQueue'
|
||||
| 'invocationCache'
|
||||
| 'modelCache'
|
||||
| 'bulkDownload'
|
||||
| 'starterModels'
|
||||
| 'hfToken'
|
||||
| 'retryQueueItem'
|
||||
| 'cancelAndClearAll'
|
||||
| 'chatGPT4oHigh'
|
||||
| 'modelRelationships';
|
||||
/**
|
||||
* A disable-able Stable Diffusion feature
|
||||
*/
|
||||
export type SDFeature =
|
||||
| 'controlNet'
|
||||
| 'noise'
|
||||
| 'perlinNoise'
|
||||
| 'noiseThreshold'
|
||||
| 'variation'
|
||||
| 'symmetry'
|
||||
| 'seamless'
|
||||
| 'hires'
|
||||
| 'lora'
|
||||
| 'embedding'
|
||||
| 'vae'
|
||||
| 'hrf';
|
||||
const zAppFeature = z.enum([
|
||||
'faceRestore',
|
||||
'upscaling',
|
||||
'lightbox',
|
||||
'modelManager',
|
||||
'githubLink',
|
||||
'discordLink',
|
||||
'bugLink',
|
||||
'aboutModal',
|
||||
'localization',
|
||||
'consoleLogging',
|
||||
'dynamicPrompting',
|
||||
'batches',
|
||||
'syncModels',
|
||||
'multiselect',
|
||||
'pauseQueue',
|
||||
'resumeQueue',
|
||||
'invocationCache',
|
||||
'modelCache',
|
||||
'bulkDownload',
|
||||
'starterModels',
|
||||
'hfToken',
|
||||
'retryQueueItem',
|
||||
'cancelAndClearAll',
|
||||
'chatGPT4oHigh',
|
||||
'modelRelationships',
|
||||
]);
|
||||
export type AppFeature = z.infer<typeof zAppFeature>;
|
||||
|
||||
export type NumericalParameterConfig = {
|
||||
initial: number;
|
||||
sliderMin: number;
|
||||
sliderMax: number;
|
||||
numberInputMin: number;
|
||||
numberInputMax: number;
|
||||
fineStep: number;
|
||||
coarseStep: number;
|
||||
};
|
||||
const zSDFeature = z.enum([
|
||||
'controlNet',
|
||||
'noise',
|
||||
'perlinNoise',
|
||||
'noiseThreshold',
|
||||
'variation',
|
||||
'symmetry',
|
||||
'seamless',
|
||||
'hires',
|
||||
'lora',
|
||||
'embedding',
|
||||
'vae',
|
||||
'hrf',
|
||||
]);
|
||||
export type SDFeature = z.infer<typeof zSDFeature>;
|
||||
|
||||
const zNumericalParameterConfig = z.object({
|
||||
initial: z.number().default(512),
|
||||
sliderMin: z.number().default(64),
|
||||
sliderMax: z.number().default(1536),
|
||||
numberInputMin: z.number().default(64),
|
||||
numberInputMax: z.number().default(4096),
|
||||
fineStep: z.number().default(8),
|
||||
coarseStep: z.number().default(64),
|
||||
});
|
||||
|
||||
/**
|
||||
* Configuration options for the InvokeAI UI.
|
||||
* Distinct from system settings which may be changed inside the app.
|
||||
*/
|
||||
export type AppConfig = {
|
||||
export const zAppConfig = z.object({
|
||||
/**
|
||||
* Whether or not we should update image urls when image loading errors
|
||||
*/
|
||||
shouldUpdateImagesOnConnect: boolean;
|
||||
shouldFetchMetadataFromApi: boolean;
|
||||
shouldUpdateImagesOnConnect: z.boolean(),
|
||||
shouldFetchMetadataFromApi: z.boolean(),
|
||||
/**
|
||||
* Sets a size limit for outputs on the upscaling tab. This is a maximum dimension, so the actual max number of pixels
|
||||
* will be the square of this value.
|
||||
*/
|
||||
maxUpscaleDimension?: number;
|
||||
allowPrivateBoards: boolean;
|
||||
allowPrivateStylePresets: boolean;
|
||||
allowClientSideUpload: boolean;
|
||||
allowPublishWorkflows: boolean;
|
||||
allowPromptExpansion: boolean;
|
||||
disabledTabs: TabName[];
|
||||
disabledFeatures: AppFeature[];
|
||||
disabledSDFeatures: SDFeature[];
|
||||
nodesAllowlist: string[] | undefined;
|
||||
nodesDenylist: string[] | undefined;
|
||||
metadataFetchDebounce?: number;
|
||||
workflowFetchDebounce?: number;
|
||||
isLocal?: boolean;
|
||||
shouldShowCredits: boolean;
|
||||
sd: {
|
||||
defaultModel?: string;
|
||||
disabledControlNetModels: string[];
|
||||
disabledControlNetProcessors: FilterType[];
|
||||
maxUpscaleDimension: z.number().optional(),
|
||||
allowPrivateBoards: z.boolean(),
|
||||
allowPrivateStylePresets: z.boolean(),
|
||||
allowClientSideUpload: z.boolean(),
|
||||
allowPublishWorkflows: z.boolean(),
|
||||
allowPromptExpansion: z.boolean(),
|
||||
disabledTabs: z.array(zTabName),
|
||||
disabledFeatures: z.array(zAppFeature),
|
||||
disabledSDFeatures: z.array(zSDFeature),
|
||||
nodesAllowlist: z.array(z.string()).optional(),
|
||||
nodesDenylist: z.array(z.string()).optional(),
|
||||
metadataFetchDebounce: z.number().int().optional(),
|
||||
workflowFetchDebounce: z.number().int().optional(),
|
||||
isLocal: z.boolean().optional(),
|
||||
shouldShowCredits: z.boolean().optional(),
|
||||
sd: z.object({
|
||||
defaultModel: z.string().optional(),
|
||||
disabledControlNetModels: z.array(z.string()),
|
||||
disabledControlNetProcessors: z.array(zFilterType),
|
||||
// Core parameters
|
||||
iterations: NumericalParameterConfig;
|
||||
width: NumericalParameterConfig; // initial value comes from model
|
||||
height: NumericalParameterConfig; // initial value comes from model
|
||||
steps: NumericalParameterConfig;
|
||||
guidance: NumericalParameterConfig;
|
||||
cfgRescaleMultiplier: NumericalParameterConfig;
|
||||
img2imgStrength: NumericalParameterConfig;
|
||||
scheduler?: ParameterScheduler;
|
||||
vaePrecision?: ParameterPrecision;
|
||||
iterations: zNumericalParameterConfig,
|
||||
width: zNumericalParameterConfig,
|
||||
height: zNumericalParameterConfig,
|
||||
steps: zNumericalParameterConfig,
|
||||
guidance: zNumericalParameterConfig,
|
||||
cfgRescaleMultiplier: zNumericalParameterConfig,
|
||||
img2imgStrength: zNumericalParameterConfig,
|
||||
scheduler: zParameterScheduler.optional(),
|
||||
vaePrecision: zParameterPrecision.optional(),
|
||||
// Canvas
|
||||
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
|
||||
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model
|
||||
scaledBoundingBoxHeight: NumericalParameterConfig; // initial value comes from model
|
||||
scaledBoundingBoxWidth: NumericalParameterConfig; // initial value comes from model
|
||||
canvasCoherenceStrength: NumericalParameterConfig;
|
||||
canvasCoherenceEdgeSize: NumericalParameterConfig;
|
||||
infillTileSize: NumericalParameterConfig;
|
||||
infillPatchmatchDownscaleSize: NumericalParameterConfig;
|
||||
boundingBoxHeight: zNumericalParameterConfig,
|
||||
boundingBoxWidth: zNumericalParameterConfig,
|
||||
scaledBoundingBoxHeight: zNumericalParameterConfig,
|
||||
scaledBoundingBoxWidth: zNumericalParameterConfig,
|
||||
canvasCoherenceStrength: zNumericalParameterConfig,
|
||||
canvasCoherenceEdgeSize: zNumericalParameterConfig,
|
||||
infillTileSize: zNumericalParameterConfig,
|
||||
infillPatchmatchDownscaleSize: zNumericalParameterConfig,
|
||||
// Misc advanced
|
||||
clipSkip: NumericalParameterConfig; // slider and input max are ignored for this, because the values depend on the model
|
||||
maskBlur: NumericalParameterConfig;
|
||||
hrfStrength: NumericalParameterConfig;
|
||||
dynamicPrompts: {
|
||||
maxPrompts: NumericalParameterConfig;
|
||||
};
|
||||
ca: {
|
||||
weight: NumericalParameterConfig;
|
||||
};
|
||||
};
|
||||
flux: {
|
||||
guidance: NumericalParameterConfig;
|
||||
};
|
||||
};
|
||||
clipSkip: zNumericalParameterConfig, // slider and input max are ignored for this, because the values depend on the model
|
||||
maskBlur: zNumericalParameterConfig,
|
||||
hrfStrength: zNumericalParameterConfig,
|
||||
dynamicPrompts: z.object({
|
||||
maxPrompts: zNumericalParameterConfig,
|
||||
}),
|
||||
ca: z.object({
|
||||
weight: zNumericalParameterConfig,
|
||||
}),
|
||||
}),
|
||||
flux: z.object({
|
||||
guidance: zNumericalParameterConfig,
|
||||
}),
|
||||
});
|
||||
|
||||
export type AppConfig = z.infer<typeof zAppConfig>;
|
||||
export type PartialAppConfig = PartialDeep<AppConfig>;
|
||||
|
||||
export const getDefaultAppConfig = (): AppConfig => ({
|
||||
isLocal: true,
|
||||
shouldUpdateImagesOnConnect: false,
|
||||
shouldFetchMetadataFromApi: false,
|
||||
allowPrivateBoards: false,
|
||||
allowPrivateStylePresets: false,
|
||||
allowClientSideUpload: false,
|
||||
allowPublishWorkflows: false,
|
||||
allowPromptExpansion: false,
|
||||
shouldShowCredits: false,
|
||||
disabledTabs: [],
|
||||
disabledFeatures: ['lightbox', 'faceRestore', 'batches'] satisfies AppFeature[],
|
||||
disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'] satisfies SDFeature[],
|
||||
sd: {
|
||||
disabledControlNetModels: [],
|
||||
disabledControlNetProcessors: [],
|
||||
iterations: {
|
||||
initial: 1,
|
||||
sliderMin: 1,
|
||||
sliderMax: 1000,
|
||||
numberInputMin: 1,
|
||||
numberInputMax: 10000,
|
||||
fineStep: 1,
|
||||
coarseStep: 1,
|
||||
},
|
||||
width: zNumericalParameterConfig.parse({}), // initial value comes from model
|
||||
height: zNumericalParameterConfig.parse({}), // initial value comes from model
|
||||
boundingBoxWidth: zNumericalParameterConfig.parse({}), // initial value comes from model
|
||||
boundingBoxHeight: zNumericalParameterConfig.parse({}), // initial value comes from model
|
||||
scaledBoundingBoxWidth: zNumericalParameterConfig.parse({}), // initial value comes from model
|
||||
scaledBoundingBoxHeight: zNumericalParameterConfig.parse({}), // initial value comes from model
|
||||
scheduler: 'dpmpp_3m_k' as const,
|
||||
vaePrecision: 'fp32' as const,
|
||||
steps: {
|
||||
initial: 30,
|
||||
sliderMin: 1,
|
||||
sliderMax: 100,
|
||||
numberInputMin: 1,
|
||||
numberInputMax: 500,
|
||||
fineStep: 1,
|
||||
coarseStep: 1,
|
||||
},
|
||||
guidance: {
|
||||
initial: 7,
|
||||
sliderMin: 1,
|
||||
sliderMax: 20,
|
||||
numberInputMin: 1,
|
||||
numberInputMax: 200,
|
||||
fineStep: 0.1,
|
||||
coarseStep: 0.5,
|
||||
},
|
||||
img2imgStrength: {
|
||||
initial: 0.7,
|
||||
sliderMin: 0,
|
||||
sliderMax: 1,
|
||||
numberInputMin: 0,
|
||||
numberInputMax: 1,
|
||||
fineStep: 0.01,
|
||||
coarseStep: 0.05,
|
||||
},
|
||||
canvasCoherenceStrength: {
|
||||
initial: 0.3,
|
||||
sliderMin: 0,
|
||||
sliderMax: 1,
|
||||
numberInputMin: 0,
|
||||
numberInputMax: 1,
|
||||
fineStep: 0.01,
|
||||
coarseStep: 0.05,
|
||||
},
|
||||
hrfStrength: {
|
||||
initial: 0.45,
|
||||
sliderMin: 0,
|
||||
sliderMax: 1,
|
||||
numberInputMin: 0,
|
||||
numberInputMax: 1,
|
||||
fineStep: 0.01,
|
||||
coarseStep: 0.05,
|
||||
},
|
||||
canvasCoherenceEdgeSize: {
|
||||
initial: 16,
|
||||
sliderMin: 0,
|
||||
sliderMax: 128,
|
||||
numberInputMin: 0,
|
||||
numberInputMax: 1024,
|
||||
fineStep: 8,
|
||||
coarseStep: 16,
|
||||
},
|
||||
cfgRescaleMultiplier: {
|
||||
initial: 0,
|
||||
sliderMin: 0,
|
||||
sliderMax: 0.99,
|
||||
numberInputMin: 0,
|
||||
numberInputMax: 0.99,
|
||||
fineStep: 0.05,
|
||||
coarseStep: 0.1,
|
||||
},
|
||||
clipSkip: {
|
||||
initial: 0,
|
||||
sliderMin: 0,
|
||||
sliderMax: 12, // determined by model selection, unused in practice
|
||||
numberInputMin: 0,
|
||||
numberInputMax: 12, // determined by model selection, unused in practice
|
||||
fineStep: 1,
|
||||
coarseStep: 1,
|
||||
},
|
||||
infillPatchmatchDownscaleSize: {
|
||||
initial: 1,
|
||||
sliderMin: 1,
|
||||
sliderMax: 10,
|
||||
numberInputMin: 1,
|
||||
numberInputMax: 10,
|
||||
fineStep: 1,
|
||||
coarseStep: 1,
|
||||
},
|
||||
infillTileSize: {
|
||||
initial: 32,
|
||||
sliderMin: 16,
|
||||
sliderMax: 64,
|
||||
numberInputMin: 16,
|
||||
numberInputMax: 256,
|
||||
fineStep: 1,
|
||||
coarseStep: 1,
|
||||
},
|
||||
maskBlur: {
|
||||
initial: 16,
|
||||
sliderMin: 0,
|
||||
sliderMax: 128,
|
||||
numberInputMin: 0,
|
||||
numberInputMax: 512,
|
||||
fineStep: 1,
|
||||
coarseStep: 1,
|
||||
},
|
||||
ca: {
|
||||
weight: {
|
||||
initial: 1,
|
||||
sliderMin: 0,
|
||||
sliderMax: 2,
|
||||
numberInputMin: -1,
|
||||
numberInputMax: 2,
|
||||
fineStep: 0.01,
|
||||
coarseStep: 0.05,
|
||||
},
|
||||
},
|
||||
dynamicPrompts: {
|
||||
maxPrompts: {
|
||||
initial: 100,
|
||||
sliderMin: 1,
|
||||
sliderMax: 1000,
|
||||
numberInputMin: 1,
|
||||
numberInputMax: 10000,
|
||||
fineStep: 1,
|
||||
coarseStep: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
flux: {
|
||||
guidance: {
|
||||
initial: 4,
|
||||
sliderMin: 2,
|
||||
sliderMax: 6,
|
||||
numberInputMin: 1,
|
||||
numberInputMax: 20,
|
||||
fineStep: 0.1,
|
||||
coarseStep: 0.5,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
@ -1,11 +0,0 @@
|
||||
import { clearIdbKeyValStore } from 'app/store/enhancers/reduxRemember/driver';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
export const useClearStorage = () => {
|
||||
const clearStorage = useCallback(() => {
|
||||
clearIdbKeyValStore();
|
||||
localStorage.clear();
|
||||
}, []);
|
||||
|
||||
return clearStorage;
|
||||
};
|
@ -1,6 +0,0 @@
|
||||
import type { ChangeBoardModalState } from './types';
|
||||
|
||||
export const initialState: ChangeBoardModalState = {
|
||||
isModalOpen: false,
|
||||
image_names: [],
|
||||
};
|
@ -1,12 +1,20 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import z from 'zod';
|
||||
|
||||
import { initialState } from './initialState';
|
||||
const zChangeBoardModalState = z.object({
|
||||
isModalOpen: z.boolean().default(false),
|
||||
image_names: z.array(z.string()).default(() => []),
|
||||
});
|
||||
type ChangeBoardModalState = z.infer<typeof zChangeBoardModalState>;
|
||||
|
||||
export const changeBoardModalSlice = createSlice({
|
||||
const getInitialState = (): ChangeBoardModalState => zChangeBoardModalState.parse({});
|
||||
|
||||
const slice = createSlice({
|
||||
name: 'changeBoardModal',
|
||||
initialState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
isModalOpenChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.isModalOpen = action.payload;
|
||||
@ -21,6 +29,12 @@ export const changeBoardModalSlice = createSlice({
|
||||
},
|
||||
});
|
||||
|
||||
export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } = changeBoardModalSlice.actions;
|
||||
export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } = slice.actions;
|
||||
|
||||
export const selectChangeBoardModalSlice = (state: RootState) => state.changeBoardModal;
|
||||
|
||||
export const changeBoardModalSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zChangeBoardModalState,
|
||||
getInitialState,
|
||||
};
|
||||
|
@ -1,4 +0,0 @@
|
||||
export type ChangeBoardModalState = {
|
||||
isModalOpen: boolean;
|
||||
image_names: string[];
|
||||
};
|
@ -1,7 +1,7 @@
|
||||
import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
|
||||
import type { Selector } from '@reduxjs/toolkit';
|
||||
import { addAppListener } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStore, RootState } from 'app/store/store';
|
||||
import { addAppListener } from 'app/store/store';
|
||||
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
|
@ -1,6 +1,7 @@
|
||||
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { zRgbaColor } from 'features/controlLayers/store/types';
|
||||
import { z } from 'zod';
|
||||
|
||||
@ -11,32 +12,32 @@ const zCanvasSettingsState = z.object({
|
||||
/**
|
||||
* Whether to show HUD (Heads-Up Display) on the canvas.
|
||||
*/
|
||||
showHUD: z.boolean().default(true),
|
||||
showHUD: z.boolean(),
|
||||
/**
|
||||
* Whether to clip lines and shapes to the generation bounding box. If disabled, lines and shapes will be clipped to
|
||||
* the canvas bounds.
|
||||
*/
|
||||
clipToBbox: z.boolean().default(false),
|
||||
clipToBbox: z.boolean(),
|
||||
/**
|
||||
* Whether to show a dynamic grid on the canvas. If disabled, a checkerboard pattern will be shown instead.
|
||||
*/
|
||||
dynamicGrid: z.boolean().default(false),
|
||||
dynamicGrid: z.boolean(),
|
||||
/**
|
||||
* Whether to invert the scroll direction when adjusting the brush or eraser width with the scroll wheel.
|
||||
*/
|
||||
invertScrollForToolWidth: z.boolean().default(false),
|
||||
invertScrollForToolWidth: z.boolean(),
|
||||
/**
|
||||
* The width of the brush tool.
|
||||
*/
|
||||
brushWidth: z.int().gt(0).default(50),
|
||||
brushWidth: z.int().gt(0),
|
||||
/**
|
||||
* The width of the eraser tool.
|
||||
*/
|
||||
eraserWidth: z.int().gt(0).default(50),
|
||||
eraserWidth: z.int().gt(0),
|
||||
/**
|
||||
* The color to use when drawing lines or filling shapes.
|
||||
*/
|
||||
color: zRgbaColor.default({ r: 31, g: 160, b: 224, a: 1 }), // invokeBlue.500
|
||||
color: zRgbaColor,
|
||||
/**
|
||||
* Whether to composite inpainted/outpainted regions back onto the source image when saving canvas generations.
|
||||
*
|
||||
@ -44,57 +45,77 @@ const zCanvasSettingsState = z.object({
|
||||
*
|
||||
* When `sendToCanvas` is disabled, this setting is ignored, masked regions will always be composited.
|
||||
*/
|
||||
outputOnlyMaskedRegions: z.boolean().default(true),
|
||||
outputOnlyMaskedRegions: z.boolean(),
|
||||
/**
|
||||
* Whether to automatically process the operations like filtering and auto-masking.
|
||||
*/
|
||||
autoProcess: z.boolean().default(true),
|
||||
autoProcess: z.boolean(),
|
||||
/**
|
||||
* The snap-to-grid setting for the canvas.
|
||||
*/
|
||||
snapToGrid: z.boolean().default(true),
|
||||
snapToGrid: z.boolean(),
|
||||
/**
|
||||
* Whether to show progress on the canvas when generating images.
|
||||
*/
|
||||
showProgressOnCanvas: z.boolean().default(true),
|
||||
showProgressOnCanvas: z.boolean(),
|
||||
/**
|
||||
* Whether to show the bounding box overlay on the canvas.
|
||||
*/
|
||||
bboxOverlay: z.boolean().default(false),
|
||||
bboxOverlay: z.boolean(),
|
||||
/**
|
||||
* Whether to preserve the masked region instead of inpainting it.
|
||||
*/
|
||||
preserveMask: z.boolean().default(false),
|
||||
preserveMask: z.boolean(),
|
||||
/**
|
||||
* Whether to show only raster layers while staging.
|
||||
*/
|
||||
isolatedStagingPreview: z.boolean().default(true),
|
||||
isolatedStagingPreview: z.boolean(),
|
||||
/**
|
||||
* Whether to show only the selected layer while filtering, transforming, or doing other operations.
|
||||
*/
|
||||
isolatedLayerPreview: z.boolean().default(true),
|
||||
isolatedLayerPreview: z.boolean(),
|
||||
/**
|
||||
* Whether to use pressure sensitivity for the brush and eraser tool when a pen device is used.
|
||||
*/
|
||||
pressureSensitivity: z.boolean().default(true),
|
||||
pressureSensitivity: z.boolean(),
|
||||
/**
|
||||
* Whether to show the rule of thirds composition guide overlay on the canvas.
|
||||
*/
|
||||
ruleOfThirds: z.boolean().default(false),
|
||||
ruleOfThirds: z.boolean(),
|
||||
/**
|
||||
* Whether to save all staging images to the gallery instead of keeping them as intermediate images.
|
||||
*/
|
||||
saveAllImagesToGallery: z.boolean().default(false),
|
||||
saveAllImagesToGallery: z.boolean(),
|
||||
/**
|
||||
* The auto-switch mode for the canvas staging area.
|
||||
*/
|
||||
stagingAreaAutoSwitch: zAutoSwitchMode.default('switch_on_start'),
|
||||
stagingAreaAutoSwitch: zAutoSwitchMode,
|
||||
});
|
||||
|
||||
type CanvasSettingsState = z.infer<typeof zCanvasSettingsState>;
|
||||
const getInitialState = () => zCanvasSettingsState.parse({});
|
||||
const getInitialState = (): CanvasSettingsState => ({
|
||||
showHUD: true,
|
||||
clipToBbox: false,
|
||||
dynamicGrid: false,
|
||||
invertScrollForToolWidth: false,
|
||||
brushWidth: 50,
|
||||
eraserWidth: 50,
|
||||
color: { r: 31, g: 160, b: 224, a: 1 }, // invokeBlue.500
|
||||
outputOnlyMaskedRegions: true,
|
||||
autoProcess: true,
|
||||
snapToGrid: true,
|
||||
showProgressOnCanvas: true,
|
||||
bboxOverlay: false,
|
||||
preserveMask: false,
|
||||
isolatedStagingPreview: true,
|
||||
isolatedLayerPreview: true,
|
||||
pressureSensitivity: true,
|
||||
ruleOfThirds: false,
|
||||
saveAllImagesToGallery: false,
|
||||
stagingAreaAutoSwitch: 'switch_on_start',
|
||||
});
|
||||
|
||||
export const canvasSettingsSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'canvasSettings',
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
@ -184,18 +205,15 @@ export const {
|
||||
settingsRuleOfThirdsToggled,
|
||||
settingsSaveAllImagesToGalleryToggled,
|
||||
settingsStagingAreaAutoSwitchChanged,
|
||||
} = canvasSettingsSlice.actions;
|
||||
} = slice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrate = (state: any): any => {
|
||||
return state;
|
||||
};
|
||||
|
||||
export const canvasSettingsPersistConfig: PersistConfig<CanvasSettingsState> = {
|
||||
name: canvasSettingsSlice.name,
|
||||
initialState: getInitialState(),
|
||||
migrate,
|
||||
persistDenylist: [],
|
||||
export const canvasSettingsSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zCanvasSettingsState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => zCanvasSettingsState.parse(state),
|
||||
},
|
||||
};
|
||||
|
||||
export const selectCanvasSettingsSlice = (s: RootState) => s.canvasSettings;
|
||||
|
@ -1,6 +1,6 @@
|
||||
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
|
||||
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
@ -80,6 +80,7 @@ import {
|
||||
isFLUXReduxConfig,
|
||||
isImagenAspectRatioID,
|
||||
isIPAdapterConfig,
|
||||
zCanvasState,
|
||||
} from './types';
|
||||
import {
|
||||
converters,
|
||||
@ -95,7 +96,7 @@ import {
|
||||
initialT2IAdapter,
|
||||
} from './util';
|
||||
|
||||
export const canvasSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'canvas',
|
||||
initialState: getInitialCanvasState(),
|
||||
reducers: {
|
||||
@ -1675,19 +1676,7 @@ export const {
|
||||
inpaintMaskDenoiseLimitChanged,
|
||||
inpaintMaskDenoiseLimitDeleted,
|
||||
// inpaintMaskRecalled,
|
||||
} = canvasSlice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrate = (state: any): any => {
|
||||
return state;
|
||||
};
|
||||
|
||||
export const canvasPersistConfig: PersistConfig<CanvasState> = {
|
||||
name: canvasSlice.name,
|
||||
initialState: getInitialCanvasState(),
|
||||
migrate,
|
||||
persistDenylist: [],
|
||||
};
|
||||
} = slice.actions;
|
||||
|
||||
const syncScaledSize = (state: CanvasState) => {
|
||||
if (API_BASE_MODELS.includes(state.bbox.modelBase)) {
|
||||
@ -1710,14 +1699,14 @@ const syncScaledSize = (state: CanvasState) => {
|
||||
|
||||
let filter = true;
|
||||
|
||||
export const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
|
||||
const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
|
||||
limit: 64,
|
||||
undoType: canvasUndo.type,
|
||||
redoType: canvasRedo.type,
|
||||
clearHistoryType: canvasClearHistory.type,
|
||||
filter: (action, _state, _history) => {
|
||||
// Ignore all actions from other slices
|
||||
if (!action.type.startsWith(canvasSlice.name)) {
|
||||
if (!action.type.startsWith(slice.name)) {
|
||||
return false;
|
||||
}
|
||||
// Throttle rapid actions of the same type
|
||||
@ -1728,6 +1717,18 @@ export const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> =
|
||||
// debug: import.meta.env.MODE === 'development',
|
||||
};
|
||||
|
||||
export const canvasSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
getInitialState: getInitialCanvasState,
|
||||
schema: zCanvasState,
|
||||
persistConfig: {
|
||||
migrate: (state) => zCanvasState.parse(state),
|
||||
},
|
||||
undoableConfig: {
|
||||
reduxUndoOptions: canvasUndoableConfig,
|
||||
},
|
||||
};
|
||||
|
||||
const doNotGroupMatcher = isAnyOf(entityBrushLineAdded, entityEraserLineAdded, entityRectAdded);
|
||||
|
||||
// Store rapid actions of the same type at most once every x time.
|
||||
|
@ -1,27 +1,29 @@
|
||||
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { useMemo } from 'react';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { assert } from 'tsafe';
|
||||
import z from 'zod';
|
||||
|
||||
type CanvasStagingAreaState = {
|
||||
_version: 1;
|
||||
canvasSessionId: string;
|
||||
canvasDiscardedQueueItems: number[];
|
||||
};
|
||||
const zCanvasStagingAreaState = z.object({
|
||||
_version: z.literal(1),
|
||||
canvasSessionId: z.string(),
|
||||
canvasDiscardedQueueItems: z.array(z.number().int()),
|
||||
});
|
||||
type CanvasStagingAreaState = z.infer<typeof zCanvasStagingAreaState>;
|
||||
|
||||
const INITIAL_STATE: CanvasStagingAreaState = {
|
||||
const getInitialState = (): CanvasStagingAreaState => ({
|
||||
_version: 1,
|
||||
canvasSessionId: getPrefixedId('canvas'),
|
||||
canvasDiscardedQueueItems: [],
|
||||
};
|
||||
});
|
||||
|
||||
const getInitialState = (): CanvasStagingAreaState => deepClone(INITIAL_STATE);
|
||||
|
||||
export const canvasSessionSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'canvasSession',
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
@ -48,26 +50,26 @@ export const canvasSessionSlice = createSlice({
|
||||
},
|
||||
});
|
||||
|
||||
export const { canvasSessionReset, canvasQueueItemDiscarded } = canvasSessionSlice.actions;
|
||||
export const { canvasSessionReset, canvasQueueItemDiscarded } = slice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrate = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
|
||||
}
|
||||
export const canvasSessionSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zCanvasStagingAreaState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => {
|
||||
assert(isPlainObject(state));
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
|
||||
}
|
||||
|
||||
return state;
|
||||
return zCanvasStagingAreaState.parse(state);
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const canvasStagingAreaPersistConfig: PersistConfig<CanvasStagingAreaState> = {
|
||||
name: canvasSessionSlice.name,
|
||||
initialState: getInitialState(),
|
||||
migrate,
|
||||
persistDenylist: [],
|
||||
};
|
||||
|
||||
export const selectCanvasSessionSlice = (s: RootState) => s[canvasSessionSlice.name];
|
||||
export const selectCanvasSessionSlice = (s: RootState) => s[slice.name];
|
||||
export const selectCanvasSessionId = createSelector(selectCanvasSessionSlice, ({ canvasSessionId }) => canvasSessionId);
|
||||
|
||||
const selectDiscardedItems = createSelector(
|
||||
|
@ -166,7 +166,7 @@ const _zFilterConfig = z.discriminatedUnion('type', [
|
||||
]);
|
||||
export type FilterConfig = z.infer<typeof _zFilterConfig>;
|
||||
|
||||
const zFilterType = z.enum([
|
||||
export const zFilterType = z.enum([
|
||||
'adjust_image',
|
||||
'canny_edge_detection',
|
||||
'color_map',
|
||||
|
@ -1,30 +1,32 @@
|
||||
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { LoRA } from 'features/controlLayers/store/types';
|
||||
import { type LoRA, zLoRA } from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import z from 'zod';
|
||||
|
||||
type LoRAsState = {
|
||||
loras: LoRA[];
|
||||
};
|
||||
const zLoRAsState = z.object({
|
||||
loras: z.array(zLoRA),
|
||||
});
|
||||
type LoRAsState = z.infer<typeof zLoRAsState>;
|
||||
|
||||
const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
|
||||
weight: 0.75,
|
||||
isEnabled: true,
|
||||
};
|
||||
|
||||
const initialState: LoRAsState = {
|
||||
const getInitialState = (): LoRAsState => ({
|
||||
loras: [],
|
||||
};
|
||||
});
|
||||
|
||||
const selectLoRA = (state: LoRAsState, id: string) => state.loras.find((lora) => lora.id === id);
|
||||
|
||||
export const lorasSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'loras',
|
||||
initialState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
loraAdded: {
|
||||
reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => {
|
||||
@ -66,24 +68,21 @@ export const lorasSlice = createSlice({
|
||||
extraReducers(builder) {
|
||||
builder.addCase(paramsReset, () => {
|
||||
// When a new session is requested, clear all LoRAs
|
||||
return deepClone(initialState);
|
||||
return getInitialState();
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const { loraAdded, loraRecalled, loraDeleted, loraWeightChanged, loraIsEnabledChanged, loraAllDeleted } =
|
||||
lorasSlice.actions;
|
||||
slice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrate = (state: any): any => {
|
||||
return state;
|
||||
};
|
||||
|
||||
export const lorasPersistConfig: PersistConfig<LoRAsState> = {
|
||||
name: lorasSlice.name,
|
||||
initialState,
|
||||
migrate,
|
||||
persistDenylist: [],
|
||||
export const lorasSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zLoRAsState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => zLoRAsState.parse(state),
|
||||
},
|
||||
};
|
||||
|
||||
export const selectLoRAsSlice = (state: RootState) => state.loras;
|
||||
|
@ -1,6 +1,7 @@
|
||||
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import { clamp } from 'es-toolkit/compat';
|
||||
@ -15,6 +16,7 @@ import {
|
||||
isChatGPT4oAspectRatioID,
|
||||
isFluxKontextAspectRatioID,
|
||||
isImagenAspectRatioID,
|
||||
zParamsState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
|
||||
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
|
||||
@ -40,7 +42,7 @@ import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/par
|
||||
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
|
||||
export const paramsSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'params',
|
||||
initialState: getInitialParamsState(),
|
||||
reducers: {
|
||||
@ -92,7 +94,12 @@ export const paramsSlice = createSlice({
|
||||
state,
|
||||
action: PayloadAction<{ model: ParameterModel | null; previousModel?: ParameterModel | null }>
|
||||
) => {
|
||||
const { model, previousModel } = action.payload;
|
||||
const { previousModel } = action.payload;
|
||||
const result = zParamsState.shape.model.safeParse(action.payload.model);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
const model = result.data;
|
||||
state.model = model;
|
||||
|
||||
// If the model base changes (e.g. SD1.5 -> SDXL), we need to change a few things
|
||||
@ -111,25 +118,53 @@ export const paramsSlice = createSlice({
|
||||
},
|
||||
vaeSelected: (state, action: PayloadAction<ParameterVAEModel | null>) => {
|
||||
// null is a valid VAE!
|
||||
state.vae = action.payload;
|
||||
const result = zParamsState.shape.vae.safeParse(action.payload);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
state.vae = result.data;
|
||||
},
|
||||
fluxVAESelected: (state, action: PayloadAction<ParameterVAEModel | null>) => {
|
||||
state.fluxVAE = action.payload;
|
||||
const result = zParamsState.shape.fluxVAE.safeParse(action.payload);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
state.fluxVAE = result.data;
|
||||
},
|
||||
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
|
||||
state.t5EncoderModel = action.payload;
|
||||
const result = zParamsState.shape.t5EncoderModel.safeParse(action.payload);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
state.t5EncoderModel = result.data;
|
||||
},
|
||||
controlLoRAModelSelected: (state, action: PayloadAction<ParameterControlLoRAModel | null>) => {
|
||||
state.controlLora = action.payload;
|
||||
const result = zParamsState.shape.controlLora.safeParse(action.payload);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
state.controlLora = result.data;
|
||||
},
|
||||
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
|
||||
state.clipEmbedModel = action.payload;
|
||||
const result = zParamsState.shape.clipEmbedModel.safeParse(action.payload);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
state.clipEmbedModel = result.data;
|
||||
},
|
||||
clipLEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPLEmbedModel | null>) => {
|
||||
state.clipLEmbedModel = action.payload;
|
||||
const result = zParamsState.shape.clipLEmbedModel.safeParse(action.payload);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
state.clipLEmbedModel = result.data;
|
||||
},
|
||||
clipGEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPGEmbedModel | null>) => {
|
||||
state.clipGEmbedModel = action.payload;
|
||||
const result = zParamsState.shape.clipGEmbedModel.safeParse(action.payload);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
state.clipGEmbedModel = result.data;
|
||||
},
|
||||
vaePrecisionChanged: (state, action: PayloadAction<ParameterPrecision>) => {
|
||||
state.vaePrecision = action.payload;
|
||||
@ -156,7 +191,11 @@ export const paramsSlice = createSlice({
|
||||
state.shouldConcatPrompts = action.payload;
|
||||
},
|
||||
refinerModelChanged: (state, action: PayloadAction<ParameterSDXLRefinerModel | null>) => {
|
||||
state.refinerModel = action.payload;
|
||||
const result = zParamsState.shape.refinerModel.safeParse(action.payload);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
state.refinerModel = result.data;
|
||||
},
|
||||
setRefinerSteps: (state, action: PayloadAction<number>) => {
|
||||
state.refinerSteps = action.payload;
|
||||
@ -397,18 +436,15 @@ export const {
|
||||
syncedToOptimalDimension,
|
||||
|
||||
paramsReset,
|
||||
} = paramsSlice.actions;
|
||||
} = slice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrate = (state: any): any => {
|
||||
return state;
|
||||
};
|
||||
|
||||
export const paramsPersistConfig: PersistConfig<ParamsState> = {
|
||||
name: paramsSlice.name,
|
||||
initialState: getInitialParamsState(),
|
||||
migrate,
|
||||
persistDenylist: [],
|
||||
export const paramsSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zParamsState,
|
||||
getInitialState: getInitialParamsState,
|
||||
persistConfig: {
|
||||
migrate: (state) => zParamsState.parse(state),
|
||||
},
|
||||
};
|
||||
|
||||
export const selectParamsSlice = (state: RootState) => state.params;
|
||||
|
@ -2,7 +2,8 @@ import { objectEquals } from '@observ33r/object-equals';
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { clamp } from 'es-toolkit/compat';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import type { FLUXReduxImageInfluence, RefImagesState } from 'features/controlLayers/store/types';
|
||||
@ -18,7 +19,7 @@ import { assert } from 'tsafe';
|
||||
import type { PartialDeep } from 'type-fest';
|
||||
|
||||
import type { CLIPVisionModelV2, IPMethodV2, RefImageState } from './types';
|
||||
import { getInitialRefImagesState, isFLUXReduxConfig, isIPAdapterConfig } from './types';
|
||||
import { getInitialRefImagesState, isFLUXReduxConfig, isIPAdapterConfig, zRefImagesState } from './types';
|
||||
import {
|
||||
getReferenceImageState,
|
||||
imageDTOToImageWithDims,
|
||||
@ -36,7 +37,7 @@ type PayloadActionWithId<T = void> = T extends void
|
||||
} & T
|
||||
>;
|
||||
|
||||
export const refImagesSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'refImages',
|
||||
initialState: getInitialRefImagesState(),
|
||||
reducers: {
|
||||
@ -263,18 +264,16 @@ export const {
|
||||
refImageFLUXReduxImageInfluenceChanged,
|
||||
refImageIsEnabledToggled,
|
||||
refImagesRecalled,
|
||||
} = refImagesSlice.actions;
|
||||
} = slice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrate = (state: any): any => {
|
||||
return state;
|
||||
};
|
||||
|
||||
export const refImagesPersistConfig: PersistConfig<RefImagesState> = {
|
||||
name: refImagesSlice.name,
|
||||
initialState: getInitialRefImagesState(),
|
||||
migrate,
|
||||
persistDenylist: ['selectedEntityId', 'isPanelOpen'],
|
||||
export const refImagesSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zRefImagesState,
|
||||
getInitialState: getInitialRefImagesState,
|
||||
persistConfig: {
|
||||
migrate: (state) => zRefImagesState.parse(state),
|
||||
persistDenylist: ['selectedEntityId', 'isPanelOpen'],
|
||||
},
|
||||
};
|
||||
|
||||
export const selectRefImagesSlice = (state: RootState) => state.refImages;
|
||||
|
@ -1,9 +1,7 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
|
||||
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type { ProgressImage } from 'features/nodes/types/common';
|
||||
import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
zParameterCanvasCoherenceMode,
|
||||
zParameterCFGRescaleMultiplier,
|
||||
@ -29,33 +27,17 @@ import {
|
||||
zParameterT5EncoderModel,
|
||||
zParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { getImageDTOSafe } from 'services/api/endpoints/images';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zId = z.string().min(1);
|
||||
const zName = z.string().min(1).nullable();
|
||||
|
||||
const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async (modelIdentifier) => {
|
||||
try {
|
||||
await fetchModelConfigByIdentifier(modelIdentifier);
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
export const zImageWithDims = z.object({
|
||||
image_name: z.string(),
|
||||
width: z.number().int().positive(),
|
||||
height: z.number().int().positive(),
|
||||
});
|
||||
|
||||
const zImageWithDims = z
|
||||
.object({
|
||||
image_name: z.string(),
|
||||
width: z.number().int().positive(),
|
||||
height: z.number().int().positive(),
|
||||
})
|
||||
.refine(async (v) => {
|
||||
const { image_name } = v;
|
||||
const imageDTO = await getImageDTOSafe(image_name, { forceRefetch: true });
|
||||
return imageDTO !== null;
|
||||
});
|
||||
export type ImageWithDims = z.infer<typeof zImageWithDims>;
|
||||
|
||||
const zImageWithDimsDataURL = z.object({
|
||||
@ -253,7 +235,7 @@ export type CanvasObjectState = z.infer<typeof zCanvasObjectState>;
|
||||
const zIPAdapterConfig = z.object({
|
||||
type: z.literal('ip_adapter'),
|
||||
image: zImageWithDims.nullable(),
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
beginEndStepPct: zBeginEndStepPct,
|
||||
method: zIPMethodV2,
|
||||
@ -268,7 +250,7 @@ export type FLUXReduxImageInfluence = z.infer<typeof zFLUXReduxImageInfluence>;
|
||||
const zFLUXReduxConfig = z.object({
|
||||
type: z.literal('flux_redux'),
|
||||
image: zImageWithDims.nullable(),
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
imageInfluence: zFLUXReduxImageInfluence.default('highest'),
|
||||
});
|
||||
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
|
||||
@ -281,14 +263,14 @@ const zChatGPT4oReferenceImageConfig = z.object({
|
||||
* But we use a model drop down to switch between different ref image types, so there needs to be a model here else
|
||||
* there will be no way to switch between ref image types.
|
||||
*/
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
});
|
||||
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;
|
||||
|
||||
const zFluxKontextReferenceImageConfig = z.object({
|
||||
type: z.literal('flux_kontext_reference_image'),
|
||||
image: zImageWithDims.nullable(),
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
});
|
||||
export type FluxKontextReferenceImageConfig = z.infer<typeof zFluxKontextReferenceImageConfig>;
|
||||
|
||||
@ -360,7 +342,7 @@ export type CanvasInpaintMaskState = z.infer<typeof zCanvasInpaintMaskState>;
|
||||
|
||||
const zControlNetConfig = z.object({
|
||||
type: z.literal('controlnet'),
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
beginEndStepPct: zBeginEndStepPct,
|
||||
controlMode: zControlModeV2,
|
||||
@ -369,7 +351,7 @@ export type ControlNetConfig = z.infer<typeof zControlNetConfig>;
|
||||
|
||||
const zT2IAdapterConfig = z.object({
|
||||
type: z.literal('t2i_adapter'),
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
beginEndStepPct: zBeginEndStepPct,
|
||||
});
|
||||
@ -378,7 +360,7 @@ export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
|
||||
const zControlLoRAConfig = z.object({
|
||||
type: z.literal('control_lora'),
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
});
|
||||
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;
|
||||
|
||||
@ -424,12 +406,13 @@ export const zCanvasEntityIdentifer = z.object({
|
||||
});
|
||||
export type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType> = { id: string; type: T };
|
||||
|
||||
export type LoRA = {
|
||||
id: string;
|
||||
isEnabled: boolean;
|
||||
model: ParameterLoRAModel;
|
||||
weight: number;
|
||||
};
|
||||
export const zLoRA = z.object({
|
||||
id: z.string(),
|
||||
isEnabled: z.boolean(),
|
||||
model: zModelIdentifierField,
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
});
|
||||
export type LoRA = z.infer<typeof zLoRA>;
|
||||
|
||||
export type EphemeralProgressImage = { sessionId: string; image: ProgressImage };
|
||||
|
||||
@ -522,62 +505,108 @@ const zDimensionsState = z.object({
|
||||
aspectRatio: zAspectRatioConfig,
|
||||
});
|
||||
|
||||
const zParamsState = z.object({
|
||||
maskBlur: z.number().default(16),
|
||||
maskBlurMethod: zParameterMaskBlurMethod.default('box'),
|
||||
canvasCoherenceMode: zParameterCanvasCoherenceMode.default('Gaussian Blur'),
|
||||
canvasCoherenceMinDenoise: zParameterStrength.default(0),
|
||||
canvasCoherenceEdgeSize: z.number().default(16),
|
||||
infillMethod: z.string().default('lama'),
|
||||
infillTileSize: z.number().default(32),
|
||||
infillPatchmatchDownscaleSize: z.number().default(1),
|
||||
infillColorValue: zRgbaColor.default({ r: 0, g: 0, b: 0, a: 1 }),
|
||||
cfgScale: zParameterCFGScale.default(7.5),
|
||||
cfgRescaleMultiplier: zParameterCFGRescaleMultiplier.default(0),
|
||||
guidance: zParameterGuidance.default(4),
|
||||
img2imgStrength: zParameterStrength.default(0.75),
|
||||
optimizedDenoisingEnabled: z.boolean().default(true),
|
||||
iterations: z.number().default(1),
|
||||
scheduler: zParameterScheduler.default('dpmpp_3m_k'),
|
||||
upscaleScheduler: zParameterScheduler.default('kdpm_2'),
|
||||
upscaleCfgScale: zParameterCFGScale.default(2),
|
||||
seed: zParameterSeed.default(0),
|
||||
shouldRandomizeSeed: z.boolean().default(true),
|
||||
steps: zParameterSteps.default(30),
|
||||
model: zParameterModel.nullable().default(null),
|
||||
vae: zParameterVAEModel.nullable().default(null),
|
||||
vaePrecision: zParameterPrecision.default('fp32'),
|
||||
fluxVAE: zParameterVAEModel.nullable().default(null),
|
||||
seamlessXAxis: z.boolean().default(false),
|
||||
seamlessYAxis: z.boolean().default(false),
|
||||
clipSkip: z.number().default(0),
|
||||
shouldUseCpuNoise: z.boolean().default(true),
|
||||
positivePrompt: zParameterPositivePrompt.default(''),
|
||||
// Negative prompt may be disabled, in which case it will be null
|
||||
negativePrompt: zParameterNegativePrompt.default(null),
|
||||
positivePrompt2: zParameterPositiveStylePromptSDXL.default(''),
|
||||
negativePrompt2: zParameterNegativeStylePromptSDXL.default(''),
|
||||
shouldConcatPrompts: z.boolean().default(true),
|
||||
refinerModel: zParameterSDXLRefinerModel.nullable().default(null),
|
||||
refinerSteps: z.number().default(20),
|
||||
refinerCFGScale: z.number().default(7.5),
|
||||
refinerScheduler: zParameterScheduler.default('euler'),
|
||||
refinerPositiveAestheticScore: z.number().default(6),
|
||||
refinerNegativeAestheticScore: z.number().default(2.5),
|
||||
refinerStart: z.number().default(0.8),
|
||||
t5EncoderModel: zParameterT5EncoderModel.nullable().default(null),
|
||||
clipEmbedModel: zParameterCLIPEmbedModel.nullable().default(null),
|
||||
clipLEmbedModel: zParameterCLIPLEmbedModel.nullable().default(null),
|
||||
clipGEmbedModel: zParameterCLIPGEmbedModel.nullable().default(null),
|
||||
controlLora: zParameterControlLoRAModel.nullable().default(null),
|
||||
dimensions: zDimensionsState.default({
|
||||
rect: { x: 0, y: 0, width: 512, height: 512 },
|
||||
aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG,
|
||||
}),
|
||||
export const zParamsState = z.object({
|
||||
maskBlur: z.number(),
|
||||
maskBlurMethod: zParameterMaskBlurMethod,
|
||||
canvasCoherenceMode: zParameterCanvasCoherenceMode,
|
||||
canvasCoherenceMinDenoise: zParameterStrength,
|
||||
canvasCoherenceEdgeSize: z.number(),
|
||||
infillMethod: z.string(),
|
||||
infillTileSize: z.number(),
|
||||
infillPatchmatchDownscaleSize: z.number(),
|
||||
infillColorValue: zRgbaColor,
|
||||
cfgScale: zParameterCFGScale,
|
||||
cfgRescaleMultiplier: zParameterCFGRescaleMultiplier,
|
||||
guidance: zParameterGuidance,
|
||||
img2imgStrength: zParameterStrength,
|
||||
optimizedDenoisingEnabled: z.boolean(),
|
||||
iterations: z.number(),
|
||||
scheduler: zParameterScheduler,
|
||||
upscaleScheduler: zParameterScheduler,
|
||||
upscaleCfgScale: zParameterCFGScale,
|
||||
seed: zParameterSeed,
|
||||
shouldRandomizeSeed: z.boolean(),
|
||||
steps: zParameterSteps,
|
||||
model: zParameterModel.nullable(),
|
||||
vae: zParameterVAEModel.nullable(),
|
||||
vaePrecision: zParameterPrecision,
|
||||
fluxVAE: zParameterVAEModel.nullable(),
|
||||
seamlessXAxis: z.boolean(),
|
||||
seamlessYAxis: z.boolean(),
|
||||
clipSkip: z.number(),
|
||||
shouldUseCpuNoise: z.boolean(),
|
||||
positivePrompt: zParameterPositivePrompt,
|
||||
negativePrompt: zParameterNegativePrompt,
|
||||
positivePrompt2: zParameterPositiveStylePromptSDXL,
|
||||
negativePrompt2: zParameterNegativeStylePromptSDXL,
|
||||
shouldConcatPrompts: z.boolean(),
|
||||
refinerModel: zParameterSDXLRefinerModel.nullable(),
|
||||
refinerSteps: z.number(),
|
||||
refinerCFGScale: z.number(),
|
||||
refinerScheduler: zParameterScheduler,
|
||||
refinerPositiveAestheticScore: z.number(),
|
||||
refinerNegativeAestheticScore: z.number(),
|
||||
refinerStart: z.number(),
|
||||
t5EncoderModel: zParameterT5EncoderModel.nullable(),
|
||||
clipEmbedModel: zParameterCLIPEmbedModel.nullable(),
|
||||
clipLEmbedModel: zParameterCLIPLEmbedModel.nullable(),
|
||||
clipGEmbedModel: zParameterCLIPGEmbedModel.nullable(),
|
||||
controlLora: zParameterControlLoRAModel.nullable(),
|
||||
dimensions: zDimensionsState,
|
||||
});
|
||||
export type ParamsState = z.infer<typeof zParamsState>;
|
||||
const INITIAL_PARAMS_STATE = zParamsState.parse({});
|
||||
export const getInitialParamsState = () => deepClone(INITIAL_PARAMS_STATE);
|
||||
export const getInitialParamsState = (): ParamsState => ({
|
||||
maskBlur: 16,
|
||||
maskBlurMethod: 'box',
|
||||
canvasCoherenceMode: 'Gaussian Blur',
|
||||
canvasCoherenceMinDenoise: 0,
|
||||
canvasCoherenceEdgeSize: 16,
|
||||
infillMethod: 'lama',
|
||||
infillTileSize: 32,
|
||||
infillPatchmatchDownscaleSize: 1,
|
||||
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
|
||||
cfgScale: 7.5,
|
||||
cfgRescaleMultiplier: 0,
|
||||
guidance: 4,
|
||||
img2imgStrength: 0.75,
|
||||
optimizedDenoisingEnabled: true,
|
||||
iterations: 1,
|
||||
scheduler: 'dpmpp_3m_k',
|
||||
upscaleScheduler: 'kdpm_2',
|
||||
upscaleCfgScale: 2,
|
||||
seed: 0,
|
||||
shouldRandomizeSeed: true,
|
||||
steps: 30,
|
||||
model: null,
|
||||
vae: null,
|
||||
vaePrecision: 'fp32',
|
||||
fluxVAE: null,
|
||||
seamlessXAxis: false,
|
||||
seamlessYAxis: false,
|
||||
clipSkip: 0,
|
||||
shouldUseCpuNoise: true,
|
||||
positivePrompt: '',
|
||||
negativePrompt: null,
|
||||
positivePrompt2: '',
|
||||
negativePrompt2: '',
|
||||
shouldConcatPrompts: true,
|
||||
refinerModel: null,
|
||||
refinerSteps: 20,
|
||||
refinerCFGScale: 7.5,
|
||||
refinerScheduler: 'euler',
|
||||
refinerPositiveAestheticScore: 6,
|
||||
refinerNegativeAestheticScore: 2.5,
|
||||
refinerStart: 0.8,
|
||||
t5EncoderModel: null,
|
||||
clipEmbedModel: null,
|
||||
clipLEmbedModel: null,
|
||||
clipGEmbedModel: null,
|
||||
controlLora: null,
|
||||
dimensions: {
|
||||
rect: { x: 0, y: 0, width: 512, height: 512 },
|
||||
aspectRatio: deepClone(DEFAULT_ASPECT_RATIO_CONFIG),
|
||||
},
|
||||
});
|
||||
|
||||
const zInpaintMasks = z.object({
|
||||
isHidden: z.boolean(),
|
||||
@ -595,38 +624,45 @@ const zRegionalGuidance = z.object({
|
||||
isHidden: z.boolean(),
|
||||
entities: z.array(zCanvasRegionalGuidanceState),
|
||||
});
|
||||
const zCanvasState = z.object({
|
||||
_version: z.literal(3).default(3),
|
||||
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
|
||||
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
|
||||
inpaintMasks: zInpaintMasks.default({ isHidden: false, entities: [] }),
|
||||
rasterLayers: zRasterLayers.default({ isHidden: false, entities: [] }),
|
||||
controlLayers: zControlLayers.default({ isHidden: false, entities: [] }),
|
||||
regionalGuidance: zRegionalGuidance.default({ isHidden: false, entities: [] }),
|
||||
bbox: zBboxState.default({
|
||||
export const zCanvasState = z.object({
|
||||
_version: z.literal(3),
|
||||
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
|
||||
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
|
||||
inpaintMasks: zInpaintMasks,
|
||||
rasterLayers: zRasterLayers,
|
||||
controlLayers: zControlLayers,
|
||||
regionalGuidance: zRegionalGuidance,
|
||||
bbox: zBboxState,
|
||||
});
|
||||
export type CanvasState = z.infer<typeof zCanvasState>;
|
||||
export const getInitialCanvasState = (): CanvasState => ({
|
||||
_version: 3,
|
||||
selectedEntityIdentifier: null,
|
||||
bookmarkedEntityIdentifier: null,
|
||||
inpaintMasks: { isHidden: false, entities: [] },
|
||||
rasterLayers: { isHidden: false, entities: [] },
|
||||
controlLayers: { isHidden: false, entities: [] },
|
||||
regionalGuidance: { isHidden: false, entities: [] },
|
||||
bbox: {
|
||||
rect: { x: 0, y: 0, width: 512, height: 512 },
|
||||
aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG,
|
||||
aspectRatio: deepClone(DEFAULT_ASPECT_RATIO_CONFIG),
|
||||
scaleMethod: 'auto',
|
||||
scaledSize: { width: 512, height: 512 },
|
||||
modelBase: 'sd-1',
|
||||
}),
|
||||
},
|
||||
});
|
||||
export type CanvasState = z.infer<typeof zCanvasState>;
|
||||
|
||||
const zRefImagesState = z.object({
|
||||
selectedEntityId: z.string().nullable().default(null),
|
||||
isPanelOpen: z.boolean().default(false),
|
||||
entities: z.array(zRefImageState).default(() => []),
|
||||
export const zRefImagesState = z.object({
|
||||
selectedEntityId: z.string().nullable(),
|
||||
isPanelOpen: z.boolean(),
|
||||
entities: z.array(zRefImageState),
|
||||
});
|
||||
export type RefImagesState = z.infer<typeof zRefImagesState>;
|
||||
const INITIAL_REF_IMAGES_STATE = zRefImagesState.parse({});
|
||||
export const getInitialRefImagesState = () => deepClone(INITIAL_REF_IMAGES_STATE);
|
||||
|
||||
/**
|
||||
* Gets a fresh canvas initial state with no references in memory to existing objects.
|
||||
*/
|
||||
const CANVAS_INITIAL_STATE = zCanvasState.parse({});
|
||||
export const getInitialCanvasState = () => deepClone(CANVAS_INITIAL_STATE);
|
||||
export const getInitialRefImagesState = (): RefImagesState => ({
|
||||
selectedEntityId: null,
|
||||
isPanelOpen: false,
|
||||
entities: [],
|
||||
});
|
||||
|
||||
export const zCanvasReferenceImageState_OLD = zCanvasEntityBase.extend({
|
||||
type: z.literal('reference_image'),
|
||||
|
@ -1,25 +1,29 @@
|
||||
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { buildZodTypeGuard } from 'common/util/zodUtils';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import { assert } from 'tsafe';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']);
|
||||
export const isSeedBehaviour = buildZodTypeGuard(zSeedBehaviour);
|
||||
export type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
|
||||
|
||||
export interface DynamicPromptsState {
|
||||
_version: 1;
|
||||
maxPrompts: number;
|
||||
combinatorial: boolean;
|
||||
prompts: string[];
|
||||
parsingError: string | undefined | null;
|
||||
isError: boolean;
|
||||
isLoading: boolean;
|
||||
seedBehaviour: SeedBehaviour;
|
||||
}
|
||||
const zDynamicPromptsState = z.object({
|
||||
_version: z.literal(1),
|
||||
maxPrompts: z.number().int().min(1).max(1000),
|
||||
combinatorial: z.boolean(),
|
||||
prompts: z.array(z.string()),
|
||||
parsingError: z.string().nullish(),
|
||||
isError: z.boolean(),
|
||||
isLoading: z.boolean(),
|
||||
seedBehaviour: zSeedBehaviour,
|
||||
});
|
||||
export type DynamicPromptsState = z.infer<typeof zDynamicPromptsState>;
|
||||
|
||||
const initialDynamicPromptsState: DynamicPromptsState = {
|
||||
const getInitialState = (): DynamicPromptsState => ({
|
||||
_version: 1,
|
||||
maxPrompts: 100,
|
||||
combinatorial: true,
|
||||
@ -28,11 +32,11 @@ const initialDynamicPromptsState: DynamicPromptsState = {
|
||||
isError: false,
|
||||
isLoading: false,
|
||||
seedBehaviour: 'PER_ITERATION',
|
||||
};
|
||||
});
|
||||
|
||||
export const dynamicPromptsSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'dynamicPrompts',
|
||||
initialState: initialDynamicPromptsState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
maxPromptsChanged: (state, action: PayloadAction<number>) => {
|
||||
state.maxPrompts = action.payload;
|
||||
@ -63,21 +67,22 @@ export const {
|
||||
isErrorChanged,
|
||||
isLoadingChanged,
|
||||
seedBehaviourChanged,
|
||||
} = dynamicPromptsSlice.actions;
|
||||
} = slice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateDynamicPromptsState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const dynamicPromptsPersistConfig: PersistConfig<DynamicPromptsState> = {
|
||||
name: dynamicPromptsSlice.name,
|
||||
initialState: initialDynamicPromptsState,
|
||||
migrate: migrateDynamicPromptsState,
|
||||
persistDenylist: ['prompts'],
|
||||
export const dynamicPromptsSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zDynamicPromptsState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => {
|
||||
assert(isPlainObject(state));
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return zDynamicPromptsState.parse(state);
|
||||
},
|
||||
persistDenylist: ['prompts', 'parsingError', 'isError', 'isLoading'],
|
||||
},
|
||||
};
|
||||
|
||||
export const selectDynamicPromptsSlice = (state: RootState) => state.dynamicPrompts;
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { MenuItem } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
@ -14,7 +15,7 @@ export const ImageMenuItemSendToUpscale = memo(() => {
|
||||
const imageDTO = useImageDTOContext();
|
||||
|
||||
const handleSendToCanvas = useCallback(() => {
|
||||
dispatch(upscaleInitialImageChanged(imageDTO));
|
||||
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
|
||||
navigationApi.switchToTab('upscaling');
|
||||
toast({
|
||||
id: 'SENT_TO_CANVAS',
|
||||
|
@ -1,13 +1,23 @@
|
||||
import { objectEquals } from '@observ33r/object-equals';
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import { uniq } from 'es-toolkit/compat';
|
||||
import type { BoardRecordOrderBy } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import type { BoardId, ComparisonMode, GalleryState, GalleryView, OrderDir } from './types';
|
||||
import {
|
||||
type BoardId,
|
||||
type ComparisonMode,
|
||||
type GalleryState,
|
||||
type GalleryView,
|
||||
type OrderDir,
|
||||
zGalleryState,
|
||||
} from './types';
|
||||
|
||||
const initialGalleryState: GalleryState = {
|
||||
const getInitialState = (): GalleryState => ({
|
||||
selection: [],
|
||||
shouldAutoSwitch: true,
|
||||
autoAssignBoardOnClick: true,
|
||||
@ -26,11 +36,11 @@ const initialGalleryState: GalleryState = {
|
||||
shouldShowArchivedBoards: false,
|
||||
boardsListOrderBy: 'created_at',
|
||||
boardsListOrderDir: 'DESC',
|
||||
};
|
||||
});
|
||||
|
||||
export const gallerySlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'gallery',
|
||||
initialState: initialGalleryState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
imageSelected: (state, action: PayloadAction<string | null>) => {
|
||||
// Let's be efficient here and not update the selection unless it has actually changed. This helps to prevent
|
||||
@ -187,21 +197,22 @@ export const {
|
||||
searchTermChanged,
|
||||
boardsListOrderByChanged,
|
||||
boardsListOrderDirChanged,
|
||||
} = gallerySlice.actions;
|
||||
} = slice.actions;
|
||||
|
||||
export const selectGallerySlice = (state: RootState) => state.gallery;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateGalleryState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const galleryPersistConfig: PersistConfig<GalleryState> = {
|
||||
name: gallerySlice.name,
|
||||
initialState: initialGalleryState,
|
||||
migrate: migrateGalleryState,
|
||||
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'imageToCompare'],
|
||||
export const gallerySliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zGalleryState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => {
|
||||
assert(isPlainObject(state));
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return zGalleryState.parse(state);
|
||||
},
|
||||
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'imageToCompare'],
|
||||
},
|
||||
};
|
||||
|
@ -0,0 +1,13 @@
|
||||
import type { S } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
import { describe, test } from 'vitest';
|
||||
|
||||
import type { BoardRecordOrderBy } from './types';
|
||||
|
||||
describe('Gallery Types', () => {
|
||||
// Ensure zod types match OpenAPI types
|
||||
test('BoardRecordOrderBy', () => {
|
||||
assert<Equals<BoardRecordOrderBy, S['BoardRecordOrderBy']>>();
|
||||
});
|
||||
});
|
@ -1,31 +1,41 @@
|
||||
import type { BoardRecordOrderBy, ImageCategory } from 'services/api/types';
|
||||
import type { ImageCategory } from 'services/api/types';
|
||||
import z from 'zod';
|
||||
|
||||
const zGalleryView = z.enum(['images', 'assets']);
|
||||
export type GalleryView = z.infer<typeof zGalleryView>;
|
||||
const zBoardId = z.union([z.literal('none'), z.intersection(z.string(), z.record(z.never(), z.never()))]);
|
||||
export type BoardId = z.infer<typeof zBoardId>;
|
||||
const zComparisonMode = z.enum(['slider', 'side-by-side', 'hover']);
|
||||
export type ComparisonMode = z.infer<typeof zComparisonMode>;
|
||||
const zComparisonFit = z.enum(['contain', 'fill']);
|
||||
export type ComparisonFit = z.infer<typeof zComparisonFit>;
|
||||
const zOrderDir = z.enum(['ASC', 'DESC']);
|
||||
export type OrderDir = z.infer<typeof zOrderDir>;
|
||||
const zBoardRecordOrderBy = z.enum(['created_at', 'board_name']);
|
||||
export type BoardRecordOrderBy = z.infer<typeof zBoardRecordOrderBy>;
|
||||
|
||||
export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
|
||||
export const ASSETS_CATEGORIES: ImageCategory[] = ['control', 'mask', 'user', 'other'];
|
||||
|
||||
export type GalleryView = 'images' | 'assets';
|
||||
export type BoardId = 'none' | (string & Record<never, never>);
|
||||
export type ComparisonMode = 'slider' | 'side-by-side' | 'hover';
|
||||
export type ComparisonFit = 'contain' | 'fill';
|
||||
export type OrderDir = 'ASC' | 'DESC';
|
||||
export const zGalleryState = z.object({
|
||||
selection: z.array(z.string()),
|
||||
shouldAutoSwitch: z.boolean(),
|
||||
autoAssignBoardOnClick: z.boolean(),
|
||||
autoAddBoardId: zBoardId,
|
||||
galleryImageMinimumWidth: z.number(),
|
||||
selectedBoardId: zBoardId,
|
||||
galleryView: zGalleryView,
|
||||
boardSearchText: z.string(),
|
||||
starredFirst: z.boolean(),
|
||||
orderDir: zOrderDir,
|
||||
searchTerm: z.string(),
|
||||
alwaysShowImageSizeBadge: z.boolean(),
|
||||
imageToCompare: z.string().nullable(),
|
||||
comparisonMode: zComparisonMode,
|
||||
comparisonFit: zComparisonFit,
|
||||
shouldShowArchivedBoards: z.boolean(),
|
||||
boardsListOrderBy: zBoardRecordOrderBy,
|
||||
boardsListOrderDir: zOrderDir,
|
||||
});
|
||||
|
||||
export type GalleryState = {
|
||||
selection: string[];
|
||||
shouldAutoSwitch: boolean;
|
||||
autoAssignBoardOnClick: boolean;
|
||||
autoAddBoardId: BoardId;
|
||||
galleryImageMinimumWidth: number;
|
||||
selectedBoardId: BoardId;
|
||||
galleryView: GalleryView;
|
||||
boardSearchText: string;
|
||||
starredFirst: boolean;
|
||||
orderDir: OrderDir;
|
||||
searchTerm: string;
|
||||
alwaysShowImageSizeBadge: boolean;
|
||||
imageToCompare: string | null;
|
||||
comparisonMode: ComparisonMode;
|
||||
comparisonFit: ComparisonFit;
|
||||
shouldShowArchivedBoards: boolean;
|
||||
boardsListOrderBy: BoardRecordOrderBy;
|
||||
boardsListOrderDir: OrderDir;
|
||||
};
|
||||
export type GalleryState = z.infer<typeof zGalleryState>;
|
||||
|
@ -58,7 +58,7 @@ export const setRegionalGuidanceReferenceImage = (arg: {
|
||||
|
||||
export const setUpscaleInitialImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
|
||||
const { imageDTO, dispatch } = arg;
|
||||
dispatch(upscaleInitialImageChanged(imageDTO));
|
||||
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
|
||||
};
|
||||
|
||||
export const setNodeImageFieldImage = (arg: {
|
||||
|
@ -89,6 +89,7 @@ import { t } from 'i18next';
|
||||
import type { ComponentType } from 'react';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig, ModelType } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
@ -787,11 +788,55 @@ const LoRAs: CollectionMetadataHandler<LoRA[]> = {
|
||||
const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
|
||||
[SingleMetadataKey]: true,
|
||||
type: 'CanvasLayers',
|
||||
parse: async (metadata) => {
|
||||
parse: async (metadata, store) => {
|
||||
const raw = getProperty(metadata, 'canvas_v2_metadata');
|
||||
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
|
||||
// the zImageWithDims schema.
|
||||
const parsed = await zCanvasMetadata.parseAsync(raw);
|
||||
|
||||
for (const entity of parsed.controlLayers) {
|
||||
if (entity.controlAdapter.model) {
|
||||
await throwIfModelDoesNotExist(entity.controlAdapter.model.key, store);
|
||||
}
|
||||
for (const object of entity.objects) {
|
||||
if (object.type === 'image' && 'image_name' in object.image) {
|
||||
await throwIfImageDoesNotExist(object.image.image_name, store);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const entity of parsed.inpaintMasks) {
|
||||
for (const object of entity.objects) {
|
||||
if (object.type === 'image' && 'image_name' in object.image) {
|
||||
await throwIfImageDoesNotExist(object.image.image_name, store);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const entity of parsed.rasterLayers) {
|
||||
for (const object of entity.objects) {
|
||||
if (object.type === 'image' && 'image_name' in object.image) {
|
||||
await throwIfImageDoesNotExist(object.image.image_name, store);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const entity of parsed.regionalGuidance) {
|
||||
for (const object of entity.objects) {
|
||||
if (object.type === 'image' && 'image_name' in object.image) {
|
||||
await throwIfImageDoesNotExist(object.image.image_name, store);
|
||||
}
|
||||
}
|
||||
for (const refImage of entity.referenceImages) {
|
||||
if (refImage.config.image) {
|
||||
await throwIfImageDoesNotExist(refImage.config.image.image_name, store);
|
||||
}
|
||||
if (refImage.config.model) {
|
||||
await throwIfModelDoesNotExist(refImage.config.model.key, store);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Promise.resolve(parsed);
|
||||
},
|
||||
recall: (value, store) => {
|
||||
@ -824,27 +869,39 @@ const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
|
||||
const RefImages: CollectionMetadataHandler<RefImageState[]> = {
|
||||
[CollectionMetadataKey]: true,
|
||||
type: 'RefImages',
|
||||
parse: async (metadata) => {
|
||||
parse: async (metadata, store) => {
|
||||
let parsed: RefImageState[] | null = null;
|
||||
try {
|
||||
// First attempt to parse from the v6 slot
|
||||
const raw = getProperty(metadata, 'ref_images');
|
||||
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
|
||||
// the zImageWithDims schema.
|
||||
const parsed = await z.array(zRefImageState).parseAsync(raw);
|
||||
return Promise.resolve(parsed);
|
||||
parsed = z.array(zRefImageState).parse(raw);
|
||||
} catch {
|
||||
// Fall back to extracting from canvas metadata]
|
||||
const raw = getProperty(metadata, 'canvas_v2_metadata.referenceImages.entities');
|
||||
// This validator fetches all referenced images. If any do not exist, validation fails. The logic for this is in
|
||||
// the zImageWithDims schema.
|
||||
const oldParsed = await z.array(zCanvasReferenceImageState_OLD).parseAsync(raw);
|
||||
const parsed: RefImageState[] = oldParsed.map(({ id, ipAdapter, isEnabled }) => ({
|
||||
parsed = oldParsed.map(({ id, ipAdapter, isEnabled }) => ({
|
||||
id,
|
||||
config: ipAdapter,
|
||||
isEnabled,
|
||||
}));
|
||||
return parsed;
|
||||
}
|
||||
|
||||
if (!parsed) {
|
||||
throw new Error('No valid reference images found in metadata');
|
||||
}
|
||||
|
||||
for (const refImage of parsed) {
|
||||
if (refImage.config.image) {
|
||||
await throwIfImageDoesNotExist(refImage.config.image.image_name, store);
|
||||
}
|
||||
if (refImage.config.model) {
|
||||
await throwIfModelDoesNotExist(refImage.config.model.key, store);
|
||||
}
|
||||
}
|
||||
|
||||
return parsed;
|
||||
},
|
||||
recall: (value, store) => {
|
||||
const entities = value.map((data) => ({ ...data, id: getPrefixedId('reference_image') }));
|
||||
@ -1241,3 +1298,19 @@ const isCompatibleWithMainModel = (candidate: ModelIdentifierField, store: AppSt
|
||||
}
|
||||
return candidate.base === base;
|
||||
};
|
||||
|
||||
const throwIfImageDoesNotExist = async (name: string, store: AppStore): Promise<void> => {
|
||||
try {
|
||||
await store.dispatch(imagesApi.endpoints.getImageDTO.initiate(name, { subscribe: false })).unwrap();
|
||||
} catch {
|
||||
throw new Error(`Image with name ${name} does not exist`);
|
||||
}
|
||||
};
|
||||
|
||||
const throwIfModelDoesNotExist = async (key: string, store: AppStore): Promise<void> => {
|
||||
try {
|
||||
await store.dispatch(modelsApi.endpoints.getModelConfig.initiate(key, { subscribe: false }));
|
||||
} catch {
|
||||
throw new Error(`Model with key ${key} does not exist`);
|
||||
}
|
||||
};
|
||||
|
@ -1,7 +1,6 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Raised when a model config is unable to be fetched.
|
||||
@ -47,45 +46,6 @@ const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches the model config for a given model name, base model, and model type. This provides backwards compatibility
|
||||
* for MM1 model identifiers.
|
||||
* @param name The model name.
|
||||
* @param base The base model.
|
||||
* @param type The model type.
|
||||
* @returns A promise that resolves to the model config.
|
||||
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
|
||||
*/
|
||||
const fetchModelConfigByAttrs = async (name: string, base: BaseModelType, type: ModelType): Promise<AnyModelConfig> => {
|
||||
const { dispatch } = getStore();
|
||||
try {
|
||||
const req = dispatch(
|
||||
modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type }, { subscribe: false })
|
||||
);
|
||||
return await req.unwrap();
|
||||
} catch {
|
||||
throw new ModelConfigNotFoundError(`Unable to retrieve model config for name/base/type ${name}/${base}/${type}`);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches the model config given an identifier. First attempts to fetch by key, then falls back to fetching by attrs.
|
||||
* @param identifier The model identifier.
|
||||
* @returns A promise that resolves to the model config.
|
||||
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
|
||||
*/
|
||||
export const fetchModelConfigByIdentifier = async (identifier: ModelIdentifierField): Promise<AnyModelConfig> => {
|
||||
try {
|
||||
return await fetchModelConfig(identifier.key);
|
||||
} catch {
|
||||
try {
|
||||
return await fetchModelConfigByAttrs(identifier.name, identifier.base, identifier.type);
|
||||
} catch {
|
||||
throw new ModelConfigNotFoundError(`Unable to retrieve model config for identifier ${identifier}`);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches the model config for a given model key and type, and ensures that the model config is of a specific type.
|
||||
* @param key The model key.
|
||||
|
@ -1,21 +1,28 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { ModelType } from 'services/api/types';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import { zModelType } from 'features/nodes/types/common';
|
||||
import { assert } from 'tsafe';
|
||||
import z from 'zod';
|
||||
|
||||
export type FilterableModelType = Exclude<ModelType, 'onnx'> | 'refiner';
|
||||
const zFilterableModelType = zModelType.exclude(['onnx']).or(z.literal('refiner'));
|
||||
export type FilterableModelType = z.infer<typeof zFilterableModelType>;
|
||||
|
||||
type ModelManagerState = {
|
||||
_version: 1;
|
||||
selectedModelKey: string | null;
|
||||
selectedModelMode: 'edit' | 'view';
|
||||
searchTerm: string;
|
||||
filteredModelType: FilterableModelType | null;
|
||||
scanPath: string | undefined;
|
||||
shouldInstallInPlace: boolean;
|
||||
};
|
||||
const zModelManagerState = z.object({
|
||||
_version: z.literal(1),
|
||||
selectedModelKey: z.string().nullable(),
|
||||
selectedModelMode: z.enum(['edit', 'view']),
|
||||
searchTerm: z.string(),
|
||||
filteredModelType: zFilterableModelType.nullable(),
|
||||
scanPath: z.string().optional(),
|
||||
shouldInstallInPlace: z.boolean(),
|
||||
});
|
||||
|
||||
const initialModelManagerState: ModelManagerState = {
|
||||
type ModelManagerState = z.infer<typeof zModelManagerState>;
|
||||
|
||||
const getInitialState = (): ModelManagerState => ({
|
||||
_version: 1,
|
||||
selectedModelKey: null,
|
||||
selectedModelMode: 'view',
|
||||
@ -23,11 +30,11 @@ const initialModelManagerState: ModelManagerState = {
|
||||
searchTerm: '',
|
||||
scanPath: undefined,
|
||||
shouldInstallInPlace: true,
|
||||
};
|
||||
});
|
||||
|
||||
export const modelManagerV2Slice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'modelmanagerV2',
|
||||
initialState: initialModelManagerState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
setSelectedModelKey: (state, action: PayloadAction<string | null>) => {
|
||||
state.selectedModelMode = 'view';
|
||||
@ -58,21 +65,22 @@ export const {
|
||||
setSelectedModelMode,
|
||||
setScanPath,
|
||||
shouldInstallInPlaceChanged,
|
||||
} = modelManagerV2Slice.actions;
|
||||
} = slice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateModelManagerState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const modelManagerV2PersistConfig: PersistConfig<ModelManagerState> = {
|
||||
name: modelManagerV2Slice.name,
|
||||
initialState: initialModelManagerState,
|
||||
migrate: migrateModelManagerState,
|
||||
persistDenylist: ['selectedModelKey', 'selectedModelMode', 'filteredModelType', 'searchTerm'],
|
||||
export const modelManagerSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zModelManagerState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => {
|
||||
assert(isPlainObject(state));
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return zModelManagerState.parse(state);
|
||||
},
|
||||
persistDenylist: ['selectedModelKey', 'selectedModelMode', 'filteredModelType', 'searchTerm'],
|
||||
},
|
||||
};
|
||||
|
||||
export const selectModelManagerV2Slice = (state: RootState) => state.modelmanagerV2;
|
||||
|
@ -14,7 +14,13 @@ import type {
|
||||
ReactFlowProps,
|
||||
ReactFlowState,
|
||||
} from '@xyflow/react';
|
||||
import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from '@xyflow/react';
|
||||
import {
|
||||
Background,
|
||||
ReactFlow,
|
||||
SelectionMode,
|
||||
useStore as useReactFlowStore,
|
||||
useUpdateNodeInternals,
|
||||
} from '@xyflow/react';
|
||||
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { $isSelectingOutputNode, $outputNodeId } from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
@ -256,7 +262,7 @@ export const Flow = memo(() => {
|
||||
style={flowStyles}
|
||||
onPaneClick={handlePaneClick}
|
||||
deleteKeyCode={null}
|
||||
selectionMode={selectionMode}
|
||||
selectionMode={selectionMode === 'full' ? SelectionMode.Full : SelectionMode.Partial}
|
||||
elevateEdgesOnSelect
|
||||
nodeDragThreshold={1}
|
||||
noDragClassName={NO_DRAG_CLASS}
|
||||
|
@ -11,14 +11,15 @@ import type {
|
||||
XYPosition,
|
||||
} from '@xyflow/react';
|
||||
import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from '@xyflow/react';
|
||||
import type { PersistConfig } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import {
|
||||
addElement,
|
||||
removeElement,
|
||||
reparentElement,
|
||||
} from 'features/nodes/components/sidePanel/builder/form-manipulation';
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import { type NodesState, zNodesState } from 'features/nodes/store/types';
|
||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
BoardFieldValue,
|
||||
@ -127,6 +128,7 @@ import {
|
||||
import { atom, computed } from 'nanostores';
|
||||
import type { MouseEvent } from 'react';
|
||||
import type { UndoableOptions } from 'redux-undo';
|
||||
import { assert } from 'tsafe';
|
||||
import type { z } from 'zod';
|
||||
|
||||
import type { PendingConnection, Templates } from './types';
|
||||
@ -151,11 +153,11 @@ export const getInitialWorkflow = (): Omit<NodesState, 'mode' | 'formFieldInitia
|
||||
};
|
||||
};
|
||||
|
||||
const initialState: NodesState = {
|
||||
const getInitialState = (): NodesState => ({
|
||||
_version: 1,
|
||||
formFieldInitialValues: {},
|
||||
...getInitialWorkflow(),
|
||||
};
|
||||
});
|
||||
|
||||
type FieldValueAction<T extends FieldValue> = PayloadAction<{
|
||||
nodeId: string;
|
||||
@ -208,9 +210,9 @@ const fieldValueReducer = <T extends FieldValue>(
|
||||
field.value = result.data;
|
||||
};
|
||||
|
||||
export const nodesSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'nodes',
|
||||
initialState: initialState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
nodesChanged: (state, action: PayloadAction<NodeChange<AnyNode>[]>) => {
|
||||
// In v12.7.0, @xyflow/react added a `domAttributes` property to the node data. One DOM attribute is
|
||||
@ -588,7 +590,7 @@ export const nodesSlice = createSlice({
|
||||
}
|
||||
node.data.notes = value;
|
||||
},
|
||||
nodeEditorReset: () => deepClone(initialState),
|
||||
nodeEditorReset: () => getInitialState(),
|
||||
workflowNameChanged: (state, action: PayloadAction<string>) => {
|
||||
state.name = action.payload;
|
||||
},
|
||||
@ -673,7 +675,7 @@ export const nodesSlice = createSlice({
|
||||
const formFieldInitialValues = getFormFieldInitialValues(workflowExtra.form, nodes);
|
||||
|
||||
return {
|
||||
...deepClone(initialState),
|
||||
...getInitialState(),
|
||||
...deepClone(workflowExtra),
|
||||
formFieldInitialValues,
|
||||
nodes: nodes.map((node) => ({ ...SHARED_NODE_PROPERTIES, ...node })),
|
||||
@ -758,7 +760,7 @@ export const {
|
||||
workflowLoaded,
|
||||
undo,
|
||||
redo,
|
||||
} = nodesSlice.actions;
|
||||
} = slice.actions;
|
||||
|
||||
export const $cursorPos = atom<XYPosition | null>(null);
|
||||
export const $templates = atom<Templates>({});
|
||||
@ -775,21 +777,6 @@ export const $lastEdgeUpdateMouseEvent = atom<MouseEvent | null>(null);
|
||||
export const $viewport = atom<Viewport>({ x: 0, y: 0, zoom: 1 });
|
||||
export const $addNodeCmdk = atom(false);
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateNodesState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const nodesPersistConfig: PersistConfig<NodesState> = {
|
||||
name: nodesSlice.name,
|
||||
initialState: initialState,
|
||||
migrate: migrateNodesState,
|
||||
persistDenylist: [],
|
||||
};
|
||||
|
||||
type NodeSelectionAction = {
|
||||
type: ReturnType<typeof nodesChanged>['type'];
|
||||
payload: NodeSelectionChange[];
|
||||
@ -893,10 +880,10 @@ const isHighFrequencyWorkflowDetailsAction = isAnyOf(
|
||||
// a note in a notes node, we don't want to create a new undo group for every keystroke.
|
||||
const isHighFrequencyNodeScopedAction = isAnyOf(nodeLabelChanged, nodeNotesChanged, notesNodeValueChanged);
|
||||
|
||||
export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
|
||||
const reduxUndoOptions: UndoableOptions<NodesState, UnknownAction> = {
|
||||
limit: 64,
|
||||
undoType: nodesSlice.actions.undo.type,
|
||||
redoType: nodesSlice.actions.redo.type,
|
||||
undoType: slice.actions.undo.type,
|
||||
redoType: slice.actions.redo.type,
|
||||
groupBy: (action, _state, _history) => {
|
||||
if (isHighFrequencyFieldChangeAction(action)) {
|
||||
// Group by type, node id and field name
|
||||
@ -928,7 +915,7 @@ export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
|
||||
},
|
||||
filter: (action, _state, _history) => {
|
||||
// Ignore all actions from other slices
|
||||
if (!action.type.startsWith(nodesSlice.name)) {
|
||||
if (!action.type.startsWith(slice.name)) {
|
||||
return false;
|
||||
}
|
||||
// Ignore actions that only select or deselect nodes and edges
|
||||
@ -943,6 +930,24 @@ export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
|
||||
},
|
||||
};
|
||||
|
||||
export const nodesSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zNodesState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => {
|
||||
assert(isPlainObject(state));
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return zNodesState.parse(state);
|
||||
},
|
||||
},
|
||||
undoableConfig: {
|
||||
reduxUndoOptions,
|
||||
},
|
||||
};
|
||||
|
||||
// The form builder's initial values are based on the current values of the node fields in the workflow.
|
||||
export const getFormFieldInitialValues = (form: BuilderForm, nodes: NodesState['nodes']) => {
|
||||
const formFieldInitialValues: Record<string, StatefulFieldValue> = {};
|
||||
|
@ -1,7 +1,8 @@
|
||||
import type { HandleType } from '@xyflow/react';
|
||||
import type { FieldInputTemplate, FieldOutputTemplate, StatefulFieldValue } from 'features/nodes/types/field';
|
||||
import type { AnyEdge, AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { type FieldInputTemplate, type FieldOutputTemplate, zStatefulFieldValue } from 'features/nodes/types/field';
|
||||
import { type InvocationTemplate, type NodeExecutionState, zAnyEdge, zAnyNode } from 'features/nodes/types/invocation';
|
||||
import { zWorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import z from 'zod';
|
||||
|
||||
export type Templates = Record<string, InvocationTemplate>;
|
||||
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
|
||||
@ -13,11 +14,13 @@ export type PendingConnection = {
|
||||
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
|
||||
};
|
||||
|
||||
export type WorkflowMode = 'edit' | 'view';
|
||||
|
||||
export type NodesState = {
|
||||
_version: 1;
|
||||
nodes: AnyNode[];
|
||||
edges: AnyEdge[];
|
||||
formFieldInitialValues: Record<string, StatefulFieldValue>;
|
||||
} & Omit<WorkflowV3, 'nodes' | 'edges' | 'is_published'>;
|
||||
export const zWorkflowMode = z.enum(['edit', 'view']);
|
||||
export type WorkflowMode = z.infer<typeof zWorkflowMode>;
|
||||
export const zNodesState = z.object({
|
||||
_version: z.literal(1),
|
||||
nodes: z.array(zAnyNode),
|
||||
edges: z.array(zAnyEdge),
|
||||
formFieldInitialValues: z.record(z.string(), zStatefulFieldValue),
|
||||
...zWorkflowV3.omit({ nodes: true, edges: true, is_published: true }).shape,
|
||||
});
|
||||
export type NodesState = z.infer<typeof zNodesState>;
|
||||
|
@ -1,34 +1,43 @@
|
||||
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { WorkflowMode } from 'features/nodes/store/types';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { type WorkflowMode, zWorkflowMode } from 'features/nodes/store/types';
|
||||
import type { WorkflowCategory } from 'features/nodes/types/workflow';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import type { SQLiteDirection, WorkflowRecordOrderBy } from 'services/api/types';
|
||||
import {
|
||||
type SQLiteDirection,
|
||||
type WorkflowRecordOrderBy,
|
||||
zSQLiteDirection,
|
||||
zWorkflowRecordOrderBy,
|
||||
} from 'services/api/types';
|
||||
import z from 'zod';
|
||||
|
||||
export type WorkflowLibraryView = 'recent' | 'yours' | 'private' | 'shared' | 'defaults' | 'published';
|
||||
const zWorkflowLibraryView = z.enum(['recent', 'yours', 'private', 'shared', 'defaults', 'published']);
|
||||
export type WorkflowLibraryView = z.infer<typeof zWorkflowLibraryView>;
|
||||
|
||||
type WorkflowLibraryState = {
|
||||
mode: WorkflowMode;
|
||||
view: WorkflowLibraryView;
|
||||
orderBy: WorkflowRecordOrderBy;
|
||||
direction: SQLiteDirection;
|
||||
searchTerm: string;
|
||||
selectedTags: string[];
|
||||
};
|
||||
const zWorkflowLibraryState = z.object({
|
||||
mode: zWorkflowMode,
|
||||
view: zWorkflowLibraryView,
|
||||
orderBy: zWorkflowRecordOrderBy,
|
||||
direction: zSQLiteDirection,
|
||||
searchTerm: z.string(),
|
||||
selectedTags: z.array(z.string()),
|
||||
});
|
||||
type WorkflowLibraryState = z.infer<typeof zWorkflowLibraryState>;
|
||||
|
||||
const initialWorkflowLibraryState: WorkflowLibraryState = {
|
||||
const getInitialState = (): WorkflowLibraryState => ({
|
||||
mode: 'view',
|
||||
searchTerm: '',
|
||||
orderBy: 'opened_at',
|
||||
direction: 'DESC',
|
||||
selectedTags: [],
|
||||
view: 'defaults',
|
||||
};
|
||||
});
|
||||
|
||||
export const workflowLibrarySlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'workflowLibrary',
|
||||
initialState: initialWorkflowLibraryState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
workflowModeChanged: (state, action: PayloadAction<WorkflowMode>) => {
|
||||
state.mode = action.payload;
|
||||
@ -73,16 +82,15 @@ export const {
|
||||
workflowLibraryTagToggled,
|
||||
workflowLibraryTagsReset,
|
||||
workflowLibraryViewChanged,
|
||||
} = workflowLibrarySlice.actions;
|
||||
} = slice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateWorkflowLibraryState = (state: any): any => state;
|
||||
|
||||
export const workflowLibraryPersistConfig: PersistConfig<WorkflowLibraryState> = {
|
||||
name: workflowLibrarySlice.name,
|
||||
initialState: initialWorkflowLibraryState,
|
||||
migrate: migrateWorkflowLibraryState,
|
||||
persistDenylist: [],
|
||||
export const workflowLibrarySliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zWorkflowLibraryState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => zWorkflowLibraryState.parse(state),
|
||||
},
|
||||
};
|
||||
|
||||
const selectWorkflowLibrarySlice = (state: RootState) => state.workflowLibrary;
|
||||
|
@ -1,8 +1,10 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import { SelectionMode } from '@xyflow/react';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import type { Selector } from 'react-redux';
|
||||
import { assert } from 'tsafe';
|
||||
import z from 'zod';
|
||||
|
||||
export const zLayeringStrategy = z.enum(['network-simplex', 'longest-path']);
|
||||
@ -11,25 +13,28 @@ export const zLayoutDirection = z.enum(['TB', 'LR']);
|
||||
type LayoutDirection = z.infer<typeof zLayoutDirection>;
|
||||
export const zNodeAlignment = z.enum(['UL', 'UR', 'DL', 'DR']);
|
||||
type NodeAlignment = z.infer<typeof zNodeAlignment>;
|
||||
const zSelectionMode = z.enum(['partial', 'full']);
|
||||
|
||||
export type WorkflowSettingsState = {
|
||||
_version: 1;
|
||||
shouldShowMinimapPanel: boolean;
|
||||
layeringStrategy: LayeringStrategy;
|
||||
nodeSpacing: number;
|
||||
layerSpacing: number;
|
||||
layoutDirection: LayoutDirection;
|
||||
shouldValidateGraph: boolean;
|
||||
shouldAnimateEdges: boolean;
|
||||
nodeAlignment: NodeAlignment;
|
||||
nodeOpacity: number;
|
||||
shouldSnapToGrid: boolean;
|
||||
shouldColorEdges: boolean;
|
||||
shouldShowEdgeLabels: boolean;
|
||||
selectionMode: SelectionMode;
|
||||
};
|
||||
const zWorkflowSettingsState = z.object({
|
||||
_version: z.literal(1),
|
||||
shouldShowMinimapPanel: z.boolean(),
|
||||
layeringStrategy: zLayeringStrategy,
|
||||
nodeSpacing: z.number(),
|
||||
layerSpacing: z.number(),
|
||||
layoutDirection: zLayoutDirection,
|
||||
shouldValidateGraph: z.boolean(),
|
||||
shouldAnimateEdges: z.boolean(),
|
||||
nodeAlignment: zNodeAlignment,
|
||||
nodeOpacity: z.number(),
|
||||
shouldSnapToGrid: z.boolean(),
|
||||
shouldColorEdges: z.boolean(),
|
||||
shouldShowEdgeLabels: z.boolean(),
|
||||
selectionMode: zSelectionMode,
|
||||
});
|
||||
|
||||
const initialState: WorkflowSettingsState = {
|
||||
export type WorkflowSettingsState = z.infer<typeof zWorkflowSettingsState>;
|
||||
|
||||
const getInitialState = (): WorkflowSettingsState => ({
|
||||
_version: 1,
|
||||
shouldShowMinimapPanel: true,
|
||||
layeringStrategy: 'network-simplex',
|
||||
@ -43,12 +48,12 @@ const initialState: WorkflowSettingsState = {
|
||||
shouldColorEdges: true,
|
||||
shouldShowEdgeLabels: false,
|
||||
nodeOpacity: 1,
|
||||
selectionMode: SelectionMode.Partial,
|
||||
};
|
||||
selectionMode: 'partial',
|
||||
});
|
||||
|
||||
export const workflowSettingsSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'workflowSettings',
|
||||
initialState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowMinimapPanel = action.payload;
|
||||
@ -87,7 +92,7 @@ export const workflowSettingsSlice = createSlice({
|
||||
state.nodeAlignment = action.payload;
|
||||
},
|
||||
selectionModeChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial;
|
||||
state.selectionMode = action.payload ? 'full' : 'partial';
|
||||
},
|
||||
},
|
||||
});
|
||||
@ -106,21 +111,21 @@ export const {
|
||||
shouldValidateGraphChanged,
|
||||
nodeOpacityChanged,
|
||||
selectionModeChanged,
|
||||
} = workflowSettingsSlice.actions;
|
||||
} = slice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateWorkflowSettingsState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const workflowSettingsPersistConfig: PersistConfig<WorkflowSettingsState> = {
|
||||
name: workflowSettingsSlice.name,
|
||||
initialState,
|
||||
migrate: migrateWorkflowSettingsState,
|
||||
persistDenylist: [],
|
||||
export const workflowSettingsSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zWorkflowSettingsState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => {
|
||||
assert(isPlainObject(state));
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return zWorkflowSettingsState.parse(state);
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const selectWorkflowSettingsSlice = (state: RootState) => state.workflowSettings;
|
||||
|
@ -92,7 +92,7 @@ export const zMainModelBase = z.enum([
|
||||
]);
|
||||
type MainModelBase = z.infer<typeof zMainModelBase>;
|
||||
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
|
||||
const zModelType = z.enum([
|
||||
export const zModelType = z.enum([
|
||||
'main',
|
||||
'vae',
|
||||
'lora',
|
||||
|
@ -43,7 +43,7 @@ export const zNotesNodeData = z.object({
|
||||
isOpen: z.boolean(),
|
||||
notes: z.string(),
|
||||
});
|
||||
const _zCurrentImageNodeData = z.object({
|
||||
const zCurrentImageNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('current_image'),
|
||||
label: z.string(),
|
||||
@ -52,12 +52,35 @@ const _zCurrentImageNodeData = z.object({
|
||||
|
||||
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
|
||||
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
|
||||
type CurrentImageNodeData = z.infer<typeof _zCurrentImageNodeData>;
|
||||
type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
|
||||
|
||||
export type InvocationNode = Node<InvocationNodeData, 'invocation'>;
|
||||
export type NotesNode = Node<NotesNodeData, 'notes'>;
|
||||
export type CurrentImageNode = Node<CurrentImageNodeData, 'current_image'>;
|
||||
export type AnyNode = InvocationNode | NotesNode | CurrentImageNode;
|
||||
const zInvocationNodeValidationSchema = z.looseObject({
|
||||
type: z.literal('invocation'),
|
||||
data: zInvocationNodeData,
|
||||
});
|
||||
const zInvocationNode = z.custom<Node<InvocationNodeData, 'invocation'>>(
|
||||
(val) => zInvocationNodeValidationSchema.safeParse(val).success
|
||||
);
|
||||
export type InvocationNode = z.infer<typeof zInvocationNode>;
|
||||
|
||||
const zNotesNodeValidationSchema = z.looseObject({
|
||||
type: z.literal('notes'),
|
||||
data: zNotesNodeData,
|
||||
});
|
||||
const zNotesNode = z.custom<Node<NotesNodeData, 'notes'>>((val) => zNotesNodeValidationSchema.safeParse(val).success);
|
||||
export type NotesNode = z.infer<typeof zNotesNode>;
|
||||
|
||||
const zCurrentImageNodeValidationSchema = z.looseObject({
|
||||
type: z.literal('current_image'),
|
||||
data: zCurrentImageNodeData,
|
||||
});
|
||||
const zCurrentImageNode = z.custom<Node<CurrentImageNodeData, 'current_image'>>(
|
||||
(val) => zCurrentImageNodeValidationSchema.safeParse(val).success
|
||||
);
|
||||
export type CurrentImageNode = z.infer<typeof zCurrentImageNode>;
|
||||
|
||||
export const zAnyNode = z.union([zInvocationNode, zNotesNode, zCurrentImageNode]);
|
||||
export type AnyNode = z.infer<typeof zAnyNode>;
|
||||
|
||||
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
|
||||
Boolean(node && node.type === 'invocation');
|
||||
@ -83,13 +106,29 @@ export type NodeExecutionState = z.infer<typeof _zNodeExecutionState>;
|
||||
// #endregion
|
||||
|
||||
// #region Edges
|
||||
const _zInvocationNodeEdgeCollapsedData = z.object({
|
||||
const zDefaultInvocationNodeEdgeValidationSchema = z.looseObject({
|
||||
type: z.literal('default'),
|
||||
});
|
||||
const zDefaultInvocationNodeEdge = z.custom<Edge<Record<string, never>, 'default'>>(
|
||||
(val) => zDefaultInvocationNodeEdgeValidationSchema.safeParse(val).success
|
||||
);
|
||||
export type DefaultInvocationNodeEdge = z.infer<typeof zDefaultInvocationNodeEdge>;
|
||||
|
||||
const zInvocationNodeEdgeCollapsedData = z.object({
|
||||
count: z.number().int().min(1),
|
||||
});
|
||||
type InvocationNodeEdgeCollapsedData = z.infer<typeof _zInvocationNodeEdgeCollapsedData>;
|
||||
export type DefaultInvocationNodeEdge = Edge<Record<string, never>, 'default'>;
|
||||
export type CollapsedInvocationNodeEdge = Edge<InvocationNodeEdgeCollapsedData, 'collapsed'>;
|
||||
export type AnyEdge = DefaultInvocationNodeEdge | CollapsedInvocationNodeEdge;
|
||||
const zInvocationNodeEdgeCollapsedValidationSchema = z.looseObject({
|
||||
type: z.literal('default'),
|
||||
data: zInvocationNodeEdgeCollapsedData,
|
||||
});
|
||||
type InvocationNodeEdgeCollapsedData = z.infer<typeof zInvocationNodeEdgeCollapsedData>;
|
||||
|
||||
const zCollapsedInvocationNodeEdge = z.custom<Edge<InvocationNodeEdgeCollapsedData, 'collapsed'>>(
|
||||
(val) => zInvocationNodeEdgeCollapsedValidationSchema.safeParse(val).success
|
||||
);
|
||||
export type CollapsedInvocationNodeEdge = z.infer<typeof zCollapsedInvocationNodeEdge>;
|
||||
export const zAnyEdge = z.union([zDefaultInvocationNodeEdge, zCollapsedInvocationNodeEdge]);
|
||||
export type AnyEdge = z.infer<typeof zAnyEdge>;
|
||||
// #endregion
|
||||
|
||||
export const isBatchNodeType = (type: string) =>
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
@ -6,13 +7,35 @@ import { ModelPicker } from 'features/parameters/components/ModelPicker';
|
||||
import { selectTileControlNetModel, tileControlnetModelChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import { useControlNetModels } from 'services/api/hooks/modelsByType';
|
||||
import type { ControlNetModelConfig } from 'services/api/types';
|
||||
import { type ControlNetModelConfig, isControlNetModelConfig } from 'services/api/types';
|
||||
|
||||
const selectTileControlNetModelConfig = createSelector(
|
||||
selectModelConfigsQuery,
|
||||
selectTileControlNetModel,
|
||||
(modelConfigs, modelIdentifierField) => {
|
||||
if (!modelConfigs.data) {
|
||||
return null;
|
||||
}
|
||||
if (!modelIdentifierField) {
|
||||
return null;
|
||||
}
|
||||
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs.data, modelIdentifierField.key);
|
||||
if (!modelConfig) {
|
||||
return null;
|
||||
}
|
||||
if (!isControlNetModelConfig(modelConfig)) {
|
||||
return null;
|
||||
}
|
||||
return modelConfig;
|
||||
}
|
||||
);
|
||||
|
||||
const ParamTileControlNetModel = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const tileControlNetModel = useAppSelector(selectTileControlNetModel);
|
||||
const tileControlNetModel = useAppSelector(selectTileControlNetModelConfig);
|
||||
const currentBaseModel = useAppSelector(selectBase);
|
||||
const [modelConfigs, { isLoading }] = useControlNetModels();
|
||||
|
||||
|
@ -1,21 +1,21 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { ImageWithDims } from 'features/controlLayers/store/types';
|
||||
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import { useMemo } from 'react';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
const createIsTooLargeToUpscaleSelector = (imageDTO?: ImageDTO | null) =>
|
||||
const createIsTooLargeToUpscaleSelector = (imageWithDims?: ImageWithDims | null) =>
|
||||
createSelector(selectUpscaleSlice, selectConfigSlice, (upscale, config) => {
|
||||
const { upscaleModel, scale } = upscale;
|
||||
const { maxUpscaleDimension } = config;
|
||||
|
||||
if (!maxUpscaleDimension || !upscaleModel || !imageDTO) {
|
||||
if (!maxUpscaleDimension || !upscaleModel || !imageWithDims) {
|
||||
// When these are missing, another warning will be shown
|
||||
return false;
|
||||
}
|
||||
|
||||
const { width, height } = imageDTO;
|
||||
const { width, height } = imageWithDims;
|
||||
|
||||
const maxPixels = maxUpscaleDimension ** 2;
|
||||
const upscaledPixels = width * scale * height * scale;
|
||||
@ -23,7 +23,7 @@ const createIsTooLargeToUpscaleSelector = (imageDTO?: ImageDTO | null) =>
|
||||
return upscaledPixels > maxPixels;
|
||||
});
|
||||
|
||||
export const useIsTooLargeToUpscale = (imageDTO?: ImageDTO | null) => {
|
||||
const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageDTO), [imageDTO]);
|
||||
export const useIsTooLargeToUpscale = (imageWithDims?: ImageWithDims | null) => {
|
||||
const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageWithDims), [imageWithDims]);
|
||||
return useAppSelector(selectIsTooLargeToUpscale);
|
||||
};
|
||||
|
@ -1,24 +1,33 @@
|
||||
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import type { ImageWithDims } from 'features/controlLayers/store/types';
|
||||
import { zImageWithDims } from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { ParameterSpandrelImageToImageModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { ControlNetModelConfig, ImageDTO } from 'services/api/types';
|
||||
import type { ControlNetModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
import z from 'zod';
|
||||
|
||||
export interface UpscaleState {
|
||||
_version: 1;
|
||||
upscaleModel: ParameterSpandrelImageToImageModel | null;
|
||||
upscaleInitialImage: ImageDTO | null;
|
||||
structure: number;
|
||||
creativity: number;
|
||||
tileControlnetModel: ControlNetModelConfig | null;
|
||||
scale: number;
|
||||
postProcessingModel: ParameterSpandrelImageToImageModel | null;
|
||||
tileSize: number;
|
||||
tileOverlap: number;
|
||||
}
|
||||
const zUpscaleState = z.object({
|
||||
_version: z.literal(2),
|
||||
upscaleModel: zModelIdentifierField.nullable(),
|
||||
upscaleInitialImage: zImageWithDims.nullable(),
|
||||
structure: z.number(),
|
||||
creativity: z.number(),
|
||||
tileControlnetModel: zModelIdentifierField.nullable(),
|
||||
scale: z.number(),
|
||||
postProcessingModel: zModelIdentifierField.nullable(),
|
||||
tileSize: z.number(),
|
||||
tileOverlap: z.number(),
|
||||
});
|
||||
|
||||
const initialUpscaleState: UpscaleState = {
|
||||
_version: 1,
|
||||
export type UpscaleState = z.infer<typeof zUpscaleState>;
|
||||
|
||||
const getInitialState = (): UpscaleState => ({
|
||||
_version: 2,
|
||||
upscaleModel: null,
|
||||
upscaleInitialImage: null,
|
||||
structure: 0,
|
||||
@ -28,16 +37,19 @@ const initialUpscaleState: UpscaleState = {
|
||||
postProcessingModel: null,
|
||||
tileSize: 1024,
|
||||
tileOverlap: 128,
|
||||
};
|
||||
});
|
||||
|
||||
export const upscaleSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'upscale',
|
||||
initialState: initialUpscaleState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
upscaleModelChanged: (state, action: PayloadAction<ParameterSpandrelImageToImageModel | null>) => {
|
||||
state.upscaleModel = action.payload;
|
||||
const result = zUpscaleState.shape.upscaleModel.safeParse(action.payload);
|
||||
if (result.success) {
|
||||
state.upscaleModel = result.data;
|
||||
}
|
||||
},
|
||||
upscaleInitialImageChanged: (state, action: PayloadAction<ImageDTO | null>) => {
|
||||
upscaleInitialImageChanged: (state, action: PayloadAction<ImageWithDims | null>) => {
|
||||
state.upscaleInitialImage = action.payload;
|
||||
},
|
||||
structureChanged: (state, action: PayloadAction<number>) => {
|
||||
@ -47,13 +59,19 @@ export const upscaleSlice = createSlice({
|
||||
state.creativity = action.payload;
|
||||
},
|
||||
tileControlnetModelChanged: (state, action: PayloadAction<ControlNetModelConfig | null>) => {
|
||||
state.tileControlnetModel = action.payload;
|
||||
const result = zUpscaleState.shape.tileControlnetModel.safeParse(action.payload);
|
||||
if (result.success) {
|
||||
state.tileControlnetModel = result.data;
|
||||
}
|
||||
},
|
||||
scaleChanged: (state, action: PayloadAction<number>) => {
|
||||
state.scale = action.payload;
|
||||
},
|
||||
postProcessingModelChanged: (state, action: PayloadAction<ParameterSpandrelImageToImageModel | null>) => {
|
||||
state.postProcessingModel = action.payload;
|
||||
const result = zUpscaleState.shape.postProcessingModel.safeParse(action.payload);
|
||||
if (result.success) {
|
||||
state.postProcessingModel = result.data;
|
||||
}
|
||||
},
|
||||
tileSizeChanged: (state, action: PayloadAction<number>) => {
|
||||
state.tileSize = action.payload;
|
||||
@ -74,21 +92,33 @@ export const {
|
||||
postProcessingModelChanged,
|
||||
tileSizeChanged,
|
||||
tileOverlapChanged,
|
||||
} = upscaleSlice.actions;
|
||||
} = slice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateUpscaleState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const upscalePersistConfig: PersistConfig<UpscaleState> = {
|
||||
name: upscaleSlice.name,
|
||||
initialState: initialUpscaleState,
|
||||
migrate: migrateUpscaleState,
|
||||
persistDenylist: [],
|
||||
export const upscaleSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zUpscaleState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => {
|
||||
assert(isPlainObject(state));
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
if (state._version === 1) {
|
||||
state._version = 2;
|
||||
// Migrate from v1 to v2: upscaleInitialImage was an ImageDTO, now it's an ImageWithDims
|
||||
if (state.upscaleInitialImage) {
|
||||
const { image_name, width, height } = state.upscaleInitialImage;
|
||||
state.upscaleInitialImage = {
|
||||
image_name,
|
||||
width,
|
||||
height,
|
||||
};
|
||||
}
|
||||
}
|
||||
return zUpscaleState.parse(state);
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const selectUpscaleSlice = (state: RootState) => state.upscale;
|
||||
|
@ -1,24 +1,27 @@
|
||||
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import z from 'zod';
|
||||
|
||||
interface QueueState {
|
||||
listCursor: number | undefined;
|
||||
listPriority: number | undefined;
|
||||
selectedQueueItem: string | undefined;
|
||||
resumeProcessorOnEnqueue: boolean;
|
||||
}
|
||||
const zQueueState = z.object({
|
||||
listCursor: z.number().optional(),
|
||||
listPriority: z.number().optional(),
|
||||
selectedQueueItem: z.string().optional(),
|
||||
resumeProcessorOnEnqueue: z.boolean(),
|
||||
});
|
||||
type QueueState = z.infer<typeof zQueueState>;
|
||||
|
||||
const initialQueueState: QueueState = {
|
||||
const getInitialState = (): QueueState => ({
|
||||
listCursor: undefined,
|
||||
listPriority: undefined,
|
||||
selectedQueueItem: undefined,
|
||||
resumeProcessorOnEnqueue: true,
|
||||
};
|
||||
});
|
||||
|
||||
export const queueSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'queue',
|
||||
initialState: initialQueueState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
listCursorChanged: (state, action: PayloadAction<number | undefined>) => {
|
||||
state.listCursor = action.payload;
|
||||
@ -33,7 +36,13 @@ export const queueSlice = createSlice({
|
||||
},
|
||||
});
|
||||
|
||||
export const { listCursorChanged, listPriorityChanged, listParamsReset } = queueSlice.actions;
|
||||
export const { listCursorChanged, listPriorityChanged, listParamsReset } = slice.actions;
|
||||
|
||||
export const queueSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zQueueState,
|
||||
getInitialState,
|
||||
};
|
||||
|
||||
const selectQueueSlice = (state: RootState) => state.queue;
|
||||
const createQueueSelector = <T>(selector: Selector<QueueState, T>) => createSelector(selectQueueSlice, selector);
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
|
||||
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import type { SetUpscaleInitialImageDndTargetData } from 'features/dnd/dnd';
|
||||
import { setUpscaleInitialImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
@ -10,11 +11,13 @@ import { selectUpscaleInitialImage, upscaleInitialImageChanged } from 'features/
|
||||
import { t } from 'i18next';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
import { useImageDTO } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const UpscaleInitialImage = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const imageDTO = useAppSelector(selectUpscaleInitialImage);
|
||||
const upscaleInitialImage = useAppSelector(selectUpscaleInitialImage);
|
||||
const imageDTO = useImageDTO(upscaleInitialImage?.image_name);
|
||||
const dndTargetData = useMemo<SetUpscaleInitialImageDndTargetData>(
|
||||
() => setUpscaleInitialImageDndTarget.getData(),
|
||||
[]
|
||||
@ -26,7 +29,7 @@ export const UpscaleInitialImage = () => {
|
||||
|
||||
const onUpload = useCallback(
|
||||
(imageDTO: ImageDTO) => {
|
||||
dispatch(upscaleInitialImageChanged(imageDTO));
|
||||
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
@ -31,8 +31,10 @@ export const UpscaleWarning = () => {
|
||||
const validModel = modelConfigs.find((cnetModel) => {
|
||||
return cnetModel.base === model?.base && cnetModel.name.toLowerCase().includes('tile');
|
||||
});
|
||||
dispatch(tileControlnetModelChanged(validModel || null));
|
||||
}, [model?.base, modelConfigs, dispatch]);
|
||||
if (tileControlnetModel?.key !== validModel?.key) {
|
||||
dispatch(tileControlnetModelChanged(validModel || null));
|
||||
}
|
||||
}, [dispatch, model?.base, modelConfigs, tileControlnetModel?.key]);
|
||||
|
||||
const isBaseModelCompatible = useMemo(() => {
|
||||
return model && ['sd-1', 'sdxl'].includes(model.base);
|
||||
|
@ -1,23 +1,33 @@
|
||||
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
|
||||
import { atom } from 'nanostores';
|
||||
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
|
||||
import { assert } from 'tsafe';
|
||||
import z from 'zod';
|
||||
|
||||
import type { StylePresetState } from './types';
|
||||
const zStylePresetState = z.object({
|
||||
activeStylePresetId: z.string().nullable(),
|
||||
searchTerm: z.string(),
|
||||
viewMode: z.boolean(),
|
||||
showPromptPreviews: z.boolean(),
|
||||
});
|
||||
|
||||
const initialState: StylePresetState = {
|
||||
type StylePresetState = z.infer<typeof zStylePresetState>;
|
||||
|
||||
const getInitialState = (): StylePresetState => ({
|
||||
activeStylePresetId: null,
|
||||
searchTerm: '',
|
||||
viewMode: false,
|
||||
showPromptPreviews: false,
|
||||
};
|
||||
});
|
||||
|
||||
export const stylePresetSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'stylePreset',
|
||||
initialState: initialState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
activeStylePresetIdChanged: (state, action: PayloadAction<string | null>) => {
|
||||
state.activeStylePresetId = action.payload;
|
||||
@ -34,7 +44,7 @@ export const stylePresetSlice = createSlice({
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(paramsReset, () => {
|
||||
return deepClone(initialState);
|
||||
return getInitialState();
|
||||
});
|
||||
builder.addMatcher(stylePresetsApi.endpoints.deleteStylePreset.matchFulfilled, (state, action) => {
|
||||
if (state.activeStylePresetId === null) {
|
||||
@ -58,21 +68,21 @@ export const stylePresetSlice = createSlice({
|
||||
});
|
||||
|
||||
export const { activeStylePresetIdChanged, searchTermChanged, viewModeChanged, showPromptPreviewsChanged } =
|
||||
stylePresetSlice.actions;
|
||||
slice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateStylePresetState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const stylePresetPersistConfig: PersistConfig<StylePresetState> = {
|
||||
name: stylePresetSlice.name,
|
||||
initialState,
|
||||
migrate: migrateStylePresetState,
|
||||
persistDenylist: [],
|
||||
export const stylePresetSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zStylePresetState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => {
|
||||
assert(isPlainObject(state));
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return zStylePresetState.parse(state);
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const selectStylePresetSlice = (state: RootState) => state.stylePreset;
|
||||
|
@ -1,6 +0,0 @@
|
||||
export type StylePresetState = {
|
||||
activeStylePresetId: string | null;
|
||||
searchTerm: string;
|
||||
viewMode: boolean;
|
||||
showPromptPreviews: boolean;
|
||||
};
|
@ -14,11 +14,11 @@ import {
|
||||
Switch,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useClearStorage } from 'app/contexts/clear-storage-context';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { buildUseBoolean } from 'common/hooks/useBoolean';
|
||||
import { useClearStorage } from 'common/hooks/useClearStorage';
|
||||
import { selectShouldUseCPUNoise, shouldUseCpuNoiseChanged } from 'features/controlLayers/store/paramsSlice';
|
||||
import { useRefreshAfterResetModal } from 'features/system/components/SettingsModal/RefreshAfterResetModal';
|
||||
import { SettingsDeveloperLogIsEnabled } from 'features/system/components/SettingsModal/SettingsDeveloperLogIsEnabled';
|
||||
|
@ -1,193 +1,25 @@
|
||||
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { AppConfig, NumericalParameterConfig, PartialAppConfig } from 'app/types/invokeai';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { getDefaultAppConfig, type PartialAppConfig, zAppConfig } from 'app/types/invokeai';
|
||||
import { merge } from 'es-toolkit/compat';
|
||||
import z from 'zod';
|
||||
|
||||
const baseDimensionConfig: NumericalParameterConfig = {
|
||||
initial: 512, // determined by model selection, unused in practice
|
||||
sliderMin: 64,
|
||||
sliderMax: 1536,
|
||||
numberInputMin: 64,
|
||||
numberInputMax: 4096,
|
||||
fineStep: 8,
|
||||
coarseStep: 64,
|
||||
};
|
||||
const zConfigState = z.object({
|
||||
...zAppConfig.shape,
|
||||
didLoad: z.boolean(),
|
||||
});
|
||||
type ConfigState = z.infer<typeof zConfigState>;
|
||||
|
||||
const initialConfigState: AppConfig & { didLoad: boolean } = {
|
||||
const getInitialState = (): ConfigState => ({
|
||||
...getDefaultAppConfig(),
|
||||
didLoad: false,
|
||||
isLocal: true,
|
||||
shouldUpdateImagesOnConnect: false,
|
||||
shouldFetchMetadataFromApi: false,
|
||||
allowPrivateBoards: false,
|
||||
allowPrivateStylePresets: false,
|
||||
allowClientSideUpload: false,
|
||||
allowPublishWorkflows: false,
|
||||
allowPromptExpansion: false,
|
||||
shouldShowCredits: false,
|
||||
disabledTabs: [],
|
||||
disabledFeatures: ['lightbox', 'faceRestore', 'batches'],
|
||||
disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'],
|
||||
nodesAllowlist: undefined,
|
||||
nodesDenylist: undefined,
|
||||
sd: {
|
||||
disabledControlNetModels: [],
|
||||
disabledControlNetProcessors: [],
|
||||
iterations: {
|
||||
initial: 1,
|
||||
sliderMin: 1,
|
||||
sliderMax: 1000,
|
||||
numberInputMin: 1,
|
||||
numberInputMax: 10000,
|
||||
fineStep: 1,
|
||||
coarseStep: 1,
|
||||
},
|
||||
width: { ...baseDimensionConfig },
|
||||
height: { ...baseDimensionConfig },
|
||||
boundingBoxWidth: { ...baseDimensionConfig },
|
||||
boundingBoxHeight: { ...baseDimensionConfig },
|
||||
scaledBoundingBoxWidth: { ...baseDimensionConfig },
|
||||
scaledBoundingBoxHeight: { ...baseDimensionConfig },
|
||||
scheduler: 'dpmpp_3m_k',
|
||||
vaePrecision: 'fp32',
|
||||
steps: {
|
||||
initial: 30,
|
||||
sliderMin: 1,
|
||||
sliderMax: 100,
|
||||
numberInputMin: 1,
|
||||
numberInputMax: 500,
|
||||
fineStep: 1,
|
||||
coarseStep: 1,
|
||||
},
|
||||
guidance: {
|
||||
initial: 7,
|
||||
sliderMin: 1,
|
||||
sliderMax: 20,
|
||||
numberInputMin: 1,
|
||||
numberInputMax: 200,
|
||||
fineStep: 0.1,
|
||||
coarseStep: 0.5,
|
||||
},
|
||||
img2imgStrength: {
|
||||
initial: 0.7,
|
||||
sliderMin: 0,
|
||||
sliderMax: 1,
|
||||
numberInputMin: 0,
|
||||
numberInputMax: 1,
|
||||
fineStep: 0.01,
|
||||
coarseStep: 0.05,
|
||||
},
|
||||
canvasCoherenceStrength: {
|
||||
initial: 0.3,
|
||||
sliderMin: 0,
|
||||
sliderMax: 1,
|
||||
numberInputMin: 0,
|
||||
numberInputMax: 1,
|
||||
fineStep: 0.01,
|
||||
coarseStep: 0.05,
|
||||
},
|
||||
hrfStrength: {
|
||||
initial: 0.45,
|
||||
sliderMin: 0,
|
||||
sliderMax: 1,
|
||||
numberInputMin: 0,
|
||||
numberInputMax: 1,
|
||||
fineStep: 0.01,
|
||||
coarseStep: 0.05,
|
||||
},
|
||||
canvasCoherenceEdgeSize: {
|
||||
initial: 16,
|
||||
sliderMin: 0,
|
||||
sliderMax: 128,
|
||||
numberInputMin: 0,
|
||||
numberInputMax: 1024,
|
||||
fineStep: 8,
|
||||
coarseStep: 16,
|
||||
},
|
||||
cfgRescaleMultiplier: {
|
||||
initial: 0,
|
||||
sliderMin: 0,
|
||||
sliderMax: 0.99,
|
||||
numberInputMin: 0,
|
||||
numberInputMax: 0.99,
|
||||
fineStep: 0.05,
|
||||
coarseStep: 0.1,
|
||||
},
|
||||
clipSkip: {
|
||||
initial: 0,
|
||||
sliderMin: 0,
|
||||
sliderMax: 12, // determined by model selection, unused in practice
|
||||
numberInputMin: 0,
|
||||
numberInputMax: 12, // determined by model selection, unused in practice
|
||||
fineStep: 1,
|
||||
coarseStep: 1,
|
||||
},
|
||||
infillPatchmatchDownscaleSize: {
|
||||
initial: 1,
|
||||
sliderMin: 1,
|
||||
sliderMax: 10,
|
||||
numberInputMin: 1,
|
||||
numberInputMax: 10,
|
||||
fineStep: 1,
|
||||
coarseStep: 1,
|
||||
},
|
||||
infillTileSize: {
|
||||
initial: 32,
|
||||
sliderMin: 16,
|
||||
sliderMax: 64,
|
||||
numberInputMin: 16,
|
||||
numberInputMax: 256,
|
||||
fineStep: 1,
|
||||
coarseStep: 1,
|
||||
},
|
||||
maskBlur: {
|
||||
initial: 16,
|
||||
sliderMin: 0,
|
||||
sliderMax: 128,
|
||||
numberInputMin: 0,
|
||||
numberInputMax: 512,
|
||||
fineStep: 1,
|
||||
coarseStep: 1,
|
||||
},
|
||||
ca: {
|
||||
weight: {
|
||||
initial: 1,
|
||||
sliderMin: 0,
|
||||
sliderMax: 2,
|
||||
numberInputMin: -1,
|
||||
numberInputMax: 2,
|
||||
fineStep: 0.01,
|
||||
coarseStep: 0.05,
|
||||
},
|
||||
},
|
||||
dynamicPrompts: {
|
||||
maxPrompts: {
|
||||
initial: 100,
|
||||
sliderMin: 1,
|
||||
sliderMax: 1000,
|
||||
numberInputMin: 1,
|
||||
numberInputMax: 10000,
|
||||
fineStep: 1,
|
||||
coarseStep: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
flux: {
|
||||
guidance: {
|
||||
initial: 4,
|
||||
sliderMin: 2,
|
||||
sliderMax: 6,
|
||||
numberInputMin: 1,
|
||||
numberInputMax: 20,
|
||||
fineStep: 0.1,
|
||||
coarseStep: 0.5,
|
||||
},
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
export const configSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'config',
|
||||
initialState: initialConfigState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
configChanged: (state, action: PayloadAction<PartialAppConfig>) => {
|
||||
merge(state, action.payload);
|
||||
@ -196,11 +28,16 @@ export const configSlice = createSlice({
|
||||
},
|
||||
});
|
||||
|
||||
export const { configChanged } = configSlice.actions;
|
||||
export const { configChanged } = slice.actions;
|
||||
|
||||
export const configSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zConfigState,
|
||||
getInitialState,
|
||||
};
|
||||
|
||||
export const selectConfigSlice = (state: RootState) => state.config;
|
||||
const createConfigSelector = <T>(selector: Selector<typeof initialConfigState, T>) =>
|
||||
createSelector(selectConfigSlice, selector);
|
||||
const createConfigSelector = <T>(selector: Selector<ConfigState, T>) => createSelector(selectConfigSlice, selector);
|
||||
|
||||
export const selectWidthConfig = createConfigSelector((config) => config.sd.width);
|
||||
export const selectHeightConfig = createConfigSelector((config) => config.sd.height);
|
||||
|
@ -3,12 +3,15 @@ import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { LogNamespace } from 'app/logging/logger';
|
||||
import { zLogNamespace } from 'app/logging/logger';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import { uniq } from 'es-toolkit/compat';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import type { Language, SystemState } from './types';
|
||||
import { type Language, type SystemState, zSystemState } from './types';
|
||||
|
||||
const initialSystemState: SystemState = {
|
||||
const getInitialState = (): SystemState => ({
|
||||
_version: 2,
|
||||
shouldConfirmOnDelete: true,
|
||||
shouldAntialiasProgressImage: false,
|
||||
@ -23,11 +26,11 @@ const initialSystemState: SystemState = {
|
||||
logNamespaces: [...zLogNamespace.options],
|
||||
shouldShowInvocationProgressDetail: false,
|
||||
shouldHighlightFocusedRegions: false,
|
||||
};
|
||||
});
|
||||
|
||||
export const systemSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'system',
|
||||
initialState: initialSystemState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
setShouldConfirmOnDelete: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldConfirmOnDelete = action.payload;
|
||||
@ -89,25 +92,25 @@ export const {
|
||||
shouldConfirmOnNewSessionToggled,
|
||||
setShouldShowInvocationProgressDetail,
|
||||
setShouldHighlightFocusedRegions,
|
||||
} = systemSlice.actions;
|
||||
} = slice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateSystemState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
if (state._version === 1) {
|
||||
state.language = (state as SystemState).language.replace('_', '-');
|
||||
state._version = 2;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const systemPersistConfig: PersistConfig<SystemState> = {
|
||||
name: systemSlice.name,
|
||||
initialState: initialSystemState,
|
||||
migrate: migrateSystemState,
|
||||
persistDenylist: [],
|
||||
export const systemSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zSystemState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => {
|
||||
assert(isPlainObject(state));
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
if (state._version === 1) {
|
||||
state.language = (state as SystemState).language.replace('_', '-');
|
||||
state._version = 2;
|
||||
}
|
||||
return zSystemState.parse(state);
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const selectSystemSlice = (state: RootState) => state.system;
|
||||
|
@ -1,4 +1,4 @@
|
||||
import type { LogLevel, LogNamespace } from 'app/logging/logger';
|
||||
import { zLogLevel, zLogNamespace } from 'app/logging/logger';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zLanguage = z.enum([
|
||||
@ -29,19 +29,20 @@ const zLanguage = z.enum([
|
||||
export type Language = z.infer<typeof zLanguage>;
|
||||
export const isLanguage = (v: unknown): v is Language => zLanguage.safeParse(v).success;
|
||||
|
||||
export interface SystemState {
|
||||
_version: 2;
|
||||
shouldConfirmOnDelete: boolean;
|
||||
shouldAntialiasProgressImage: boolean;
|
||||
shouldConfirmOnNewSession: boolean;
|
||||
language: Language;
|
||||
shouldUseNSFWChecker: boolean;
|
||||
shouldUseWatermarker: boolean;
|
||||
shouldEnableInformationalPopovers: boolean;
|
||||
shouldEnableModelDescriptions: boolean;
|
||||
logIsEnabled: boolean;
|
||||
logLevel: LogLevel;
|
||||
logNamespaces: LogNamespace[];
|
||||
shouldShowInvocationProgressDetail: boolean;
|
||||
shouldHighlightFocusedRegions: boolean;
|
||||
}
|
||||
export const zSystemState = z.object({
|
||||
_version: z.literal(2),
|
||||
shouldConfirmOnDelete: z.boolean(),
|
||||
shouldAntialiasProgressImage: z.boolean(),
|
||||
shouldConfirmOnNewSession: z.boolean(),
|
||||
language: zLanguage,
|
||||
shouldUseNSFWChecker: z.boolean(),
|
||||
shouldUseWatermarker: z.boolean(),
|
||||
shouldEnableInformationalPopovers: z.boolean(),
|
||||
shouldEnableModelDescriptions: z.boolean(),
|
||||
logIsEnabled: z.boolean(),
|
||||
logLevel: zLogLevel,
|
||||
logNamespaces: z.array(zLogNamespace),
|
||||
shouldShowInvocationProgressDetail: z.boolean(),
|
||||
shouldHighlightFocusedRegions: z.boolean(),
|
||||
});
|
||||
export type SystemState = z.infer<typeof zSystemState>;
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { Box, Button, ButtonGroup, Flex, Grid, Heading, Icon, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import { setUpscaleInitialImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import {
|
||||
@ -37,7 +38,7 @@ export const UpscalingLaunchpadPanel = memo(() => {
|
||||
|
||||
const onUpload = useCallback(
|
||||
(imageDTO: ImageDTO) => {
|
||||
dispatch(upscaleInitialImageChanged(imageDTO));
|
||||
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
@ -1,11 +1,13 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import type { UIState } from './uiTypes';
|
||||
import { getInitialUIState } from './uiTypes';
|
||||
import { getInitialUIState, type UIState, zUIState } from './uiTypes';
|
||||
|
||||
export const uiSlice = createSlice({
|
||||
const slice = createSlice({
|
||||
name: 'ui',
|
||||
initialState: getInitialUIState(),
|
||||
reducers: {
|
||||
@ -81,29 +83,30 @@ export const {
|
||||
textAreaSizesStateChanged,
|
||||
dockviewStorageKeyChanged,
|
||||
pickerCompactViewStateChanged,
|
||||
} = uiSlice.actions;
|
||||
} = slice.actions;
|
||||
|
||||
export const selectUiSlice = (state: RootState) => state.ui;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateUIState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
if (state._version === 1) {
|
||||
state.activeTab = 'generation';
|
||||
state._version = 2;
|
||||
}
|
||||
if (state._version === 2) {
|
||||
state.activeTab = 'canvas';
|
||||
state._version = 3;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const uiPersistConfig: PersistConfig<UIState> = {
|
||||
name: uiSlice.name,
|
||||
initialState: getInitialUIState(),
|
||||
migrate: migrateUIState,
|
||||
persistDenylist: ['shouldShowImageDetails'],
|
||||
export const uiSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zUIState,
|
||||
getInitialState: getInitialUIState,
|
||||
persistConfig: {
|
||||
migrate: (state) => {
|
||||
assert(isPlainObject(state));
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
if (state._version === 1) {
|
||||
state.activeTab = 'generation';
|
||||
state._version = 2;
|
||||
}
|
||||
if (state._version === 2) {
|
||||
state.activeTab = 'canvas';
|
||||
state._version = 3;
|
||||
}
|
||||
return zUIState.parse(state);
|
||||
},
|
||||
persistDenylist: ['shouldShowImageDetails'],
|
||||
},
|
||||
};
|
||||
|
@ -1,8 +1,7 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zTabName = z.enum(['generate', 'canvas', 'upscaling', 'workflows', 'models', 'queue']);
|
||||
export const zTabName = z.enum(['generate', 'canvas', 'upscaling', 'workflows', 'models', 'queue']);
|
||||
export type TabName = z.infer<typeof zTabName>;
|
||||
|
||||
const zPartialDimensions = z.object({
|
||||
@ -13,18 +12,28 @@ const zPartialDimensions = z.object({
|
||||
const zSerializable = z.any().refine(isPlainObject);
|
||||
export type Serializable = z.infer<typeof zSerializable>;
|
||||
|
||||
const zUIState = z.object({
|
||||
_version: z.literal(3).default(3),
|
||||
activeTab: zTabName.default('generate'),
|
||||
shouldShowImageDetails: z.boolean().default(false),
|
||||
shouldShowProgressInViewer: z.boolean().default(true),
|
||||
accordions: z.record(z.string(), z.boolean()).default(() => ({})),
|
||||
expanders: z.record(z.string(), z.boolean()).default(() => ({})),
|
||||
textAreaSizes: z.record(z.string(), zPartialDimensions).default({}),
|
||||
panels: z.record(z.string(), zSerializable).default({}),
|
||||
shouldShowNotificationV2: z.boolean().default(true),
|
||||
pickerCompactViewStates: z.record(z.string(), z.boolean()).default(() => ({})),
|
||||
export const zUIState = z.object({
|
||||
_version: z.literal(3),
|
||||
activeTab: zTabName,
|
||||
shouldShowImageDetails: z.boolean(),
|
||||
shouldShowProgressInViewer: z.boolean(),
|
||||
accordions: z.record(z.string(), z.boolean()),
|
||||
expanders: z.record(z.string(), z.boolean()),
|
||||
textAreaSizes: z.record(z.string(), zPartialDimensions),
|
||||
panels: z.record(z.string(), zSerializable),
|
||||
shouldShowNotificationV2: z.boolean(),
|
||||
pickerCompactViewStates: z.record(z.string(), z.boolean()),
|
||||
});
|
||||
const INITIAL_STATE = zUIState.parse({});
|
||||
export type UIState = z.infer<typeof zUIState>;
|
||||
export const getInitialUIState = (): UIState => deepClone(INITIAL_STATE);
|
||||
export const getInitialUIState = (): UIState => ({
|
||||
_version: 3 as const,
|
||||
activeTab: 'generate' as const,
|
||||
shouldShowImageDetails: false,
|
||||
shouldShowProgressInViewer: true,
|
||||
accordions: {},
|
||||
expanders: {},
|
||||
textAreaSizes: {},
|
||||
panels: {},
|
||||
shouldShowNotificationV2: true,
|
||||
pickerCompactViewStates: {},
|
||||
});
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { $openAPISchemaUrl } from 'app/store/nanostores/openAPISchemaUrl';
|
||||
import type { OpenAPIV3_1 } from 'openapi-types';
|
||||
import type { stringify } from 'querystring';
|
||||
import type { paths } from 'services/api/schema';
|
||||
import type { AppConfig, AppVersion } from 'services/api/types';
|
||||
|
||||
@ -11,7 +12,8 @@ import { api, buildV1Url } from '..';
|
||||
* buildAppInfoUrl('some-path')
|
||||
* // '/api/v1/app/some-path'
|
||||
*/
|
||||
const buildAppInfoUrl = (path: string = '') => buildV1Url(`app/${path}`);
|
||||
export const buildAppInfoUrl = (path: string = '', query?: Parameters<typeof stringify>[0]) =>
|
||||
buildV1Url(`app/${path}`, query);
|
||||
|
||||
export const appInfoApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
@ -87,6 +89,31 @@ export const appInfoApi = api.injectEndpoints({
|
||||
},
|
||||
providesTags: ['Schema'],
|
||||
}),
|
||||
getClientStateByKey: build.query<
|
||||
paths['/api/v1/app/client_state']['get']['responses']['200']['content']['application/json'],
|
||||
paths['/api/v1/app/client_state']['get']['parameters']['query']
|
||||
>({
|
||||
query: () => ({
|
||||
url: buildAppInfoUrl('client_state'),
|
||||
method: 'GET',
|
||||
}),
|
||||
}),
|
||||
setClientStateByKey: build.mutation<
|
||||
paths['/api/v1/app/client_state']['post']['responses']['200']['content']['application/json'],
|
||||
paths['/api/v1/app/client_state']['post']['requestBody']['content']['application/json']
|
||||
>({
|
||||
query: (body) => ({
|
||||
url: buildAppInfoUrl('client_state'),
|
||||
method: 'POST',
|
||||
body,
|
||||
}),
|
||||
}),
|
||||
deleteClientState: build.mutation<void, void>({
|
||||
query: () => ({
|
||||
url: buildAppInfoUrl('client_state'),
|
||||
method: 'DELETE',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
|
@ -57,13 +57,18 @@ const tagTypes = [
|
||||
// This is invalidated on reconnect. It should be used for queries that have changing data,
|
||||
// especially related to the queue and generation.
|
||||
'FetchOnReconnect',
|
||||
'ClientState',
|
||||
] as const;
|
||||
export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>;
|
||||
export const LIST_TAG = 'LIST';
|
||||
export const LIST_ALL_TAG = 'LIST_ALL';
|
||||
|
||||
const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryError> = (args, api, extraOptions) => {
|
||||
export const getBaseUrl = (): string => {
|
||||
const baseUrl = $baseUrl.get();
|
||||
return baseUrl || window.location.href.replace(/\/$/, '');
|
||||
};
|
||||
|
||||
const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryError> = (args, api, extraOptions) => {
|
||||
const authToken = $authToken.get();
|
||||
const projectId = $projectId.get();
|
||||
const isOpenAPIRequest =
|
||||
@ -71,7 +76,7 @@ const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryE
|
||||
(typeof args === 'string' && args.includes('openapi.json'));
|
||||
|
||||
const fetchBaseQueryArgs: FetchBaseQueryArgs = {
|
||||
baseUrl: baseUrl || window.location.href.replace(/\/$/, ''),
|
||||
baseUrl: getBaseUrl(),
|
||||
};
|
||||
|
||||
// When fetching the openapi.json, we need to remove circular references from the JSON.
|
||||
|
@ -1164,6 +1164,34 @@ export type paths = {
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/v1/app/client_state": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
/**
|
||||
* Get Client State By Key
|
||||
* @description Gets the client state
|
||||
*/
|
||||
get: operations["get_client_state_by_key"];
|
||||
put?: never;
|
||||
/**
|
||||
* Set Client State
|
||||
* @description Sets the client state
|
||||
*/
|
||||
post: operations["set_client_state"];
|
||||
/**
|
||||
* Delete Client State
|
||||
* @description Deletes the client state
|
||||
*/
|
||||
delete: operations["delete_client_state"];
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/v1/queue/{queue_id}/enqueue_batch": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
@ -24697,6 +24725,101 @@ export interface operations {
|
||||
};
|
||||
};
|
||||
};
|
||||
get_client_state_by_key: {
|
||||
parameters: {
|
||||
query: {
|
||||
/** @description Key to get */
|
||||
key: string;
|
||||
};
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody?: never;
|
||||
responses: {
|
||||
/** @description Successful Response */
|
||||
200: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"application/json": components["schemas"]["JsonValue"] | null;
|
||||
};
|
||||
};
|
||||
/** @description Validation Error */
|
||||
422: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"application/json": components["schemas"]["HTTPValidationError"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
set_client_state: {
|
||||
parameters: {
|
||||
query: {
|
||||
/** @description Key to set */
|
||||
key: string;
|
||||
};
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["JsonValue"];
|
||||
};
|
||||
};
|
||||
responses: {
|
||||
/** @description Successful Response */
|
||||
200: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"application/json": unknown;
|
||||
};
|
||||
};
|
||||
/** @description Validation Error */
|
||||
422: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"application/json": components["schemas"]["HTTPValidationError"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
delete_client_state: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody?: never;
|
||||
responses: {
|
||||
/** @description Successful Response */
|
||||
200: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"application/json": unknown;
|
||||
};
|
||||
};
|
||||
/** @description Client state deleted */
|
||||
204: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content?: never;
|
||||
};
|
||||
};
|
||||
};
|
||||
enqueue_batch: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
|
@ -1,6 +1,9 @@
|
||||
import type { Dimensions } from 'features/controlLayers/store/types';
|
||||
import type { components, paths } from 'services/api/schema';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
import type { JsonObject, SetRequired } from 'type-fest';
|
||||
import z from 'zod';
|
||||
|
||||
export type S = components['schemas'];
|
||||
|
||||
@ -33,10 +36,36 @@ export type InvocationJSONSchemaExtra = S['UIConfigBase'];
|
||||
export type AppVersion = S['AppVersion'];
|
||||
export type AppConfig = S['AppConfig'];
|
||||
|
||||
const zResourceOrigin = z.enum(['internal', 'external']);
|
||||
type ResourceOrigin = z.infer<typeof zResourceOrigin>;
|
||||
assert<Equals<ResourceOrigin, S['ResourceOrigin']>>();
|
||||
const zImageCategory = z.enum(['general', 'mask', 'control', 'user', 'other']);
|
||||
export type ImageCategory = z.infer<typeof zImageCategory>;
|
||||
assert<Equals<ImageCategory, S['ImageCategory']>>();
|
||||
|
||||
// Images
|
||||
export type ImageDTO = S['ImageDTO'];
|
||||
const _zImageDTO = z.object({
|
||||
image_name: z.string(),
|
||||
image_url: z.string(),
|
||||
thumbnail_url: z.string(),
|
||||
image_origin: zResourceOrigin,
|
||||
image_category: zImageCategory,
|
||||
width: z.number().int().gt(0),
|
||||
height: z.number().int().gt(0),
|
||||
created_at: z.string(),
|
||||
updated_at: z.string(),
|
||||
deleted_at: z.string().nullish(),
|
||||
is_intermediate: z.boolean(),
|
||||
session_id: z.string().nullish(),
|
||||
node_id: z.string().nullish(),
|
||||
starred: z.boolean(),
|
||||
has_workflow: z.boolean(),
|
||||
board_id: z.string().nullish(),
|
||||
});
|
||||
export type ImageDTO = z.infer<typeof _zImageDTO>;
|
||||
assert<Equals<ImageDTO, S['ImageDTO']>>();
|
||||
|
||||
export type BoardDTO = S['BoardDTO'];
|
||||
export type ImageCategory = S['ImageCategory'];
|
||||
export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDTO_'];
|
||||
|
||||
// Models
|
||||
@ -298,8 +327,13 @@ export type ModelInstallStatus = S['InstallStatus'];
|
||||
export type Graph = S['Graph'];
|
||||
export type NonNullableGraph = SetRequired<Graph, 'nodes' | 'edges'>;
|
||||
export type Batch = S['Batch'];
|
||||
export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy'];
|
||||
export type SQLiteDirection = S['SQLiteDirection'];
|
||||
export const zWorkflowRecordOrderBy = z.enum(['name', 'created_at', 'updated_at', 'opened_at']);
|
||||
export type WorkflowRecordOrderBy = z.infer<typeof zWorkflowRecordOrderBy>;
|
||||
assert<Equals<S['WorkflowRecordOrderBy'], WorkflowRecordOrderBy>>();
|
||||
|
||||
export const zSQLiteDirection = z.enum(['ASC', 'DESC']);
|
||||
export type SQLiteDirection = z.infer<typeof zSQLiteDirection>;
|
||||
assert<Equals<S['SQLiteDirection'], SQLiteDirection>>();
|
||||
export type WorkflowRecordListItemWithThumbnailDTO = S['WorkflowRecordListItemWithThumbnailDTO'];
|
||||
|
||||
type KeysOfUnion<T> = T extends T ? keyof T : never;
|
||||
|
@ -1,12 +1,12 @@
|
||||
import { ExternalLink } from '@invoke-ai/ui-library';
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { listenerMiddleware } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketConnected } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
|
||||
import { $queueId } from 'app/store/nanostores/queueId';
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import { listenerMiddleware } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { forEach, isNil, round } from 'es-toolkit/compat';
|
||||
import {
|
||||
|
@ -1 +1 @@
|
||||
__version__ = "6.2.0"
|
||||
__version__ = "6.3.0a1"
|
||||
|
@ -67,6 +67,7 @@ def mock_services() -> InvocationServices:
|
||||
workflow_thumbnails=None, # type: ignore
|
||||
model_relationship_records=None, # type: ignore
|
||||
model_relationships=None, # type: ignore
|
||||
client_state_persistence=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user