mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes, ui): wip image types
This commit is contained in:
parent
bdab73701f
commit
9317b42e5f
@ -34,11 +34,10 @@ async def upload_image(
|
||||
file: UploadFile,
|
||||
request: Request,
|
||||
response: Response,
|
||||
image_category: ImageCategory = Query(
|
||||
default=ImageCategory.GENERAL, description="The category of the image"
|
||||
),
|
||||
is_intermediate: bool = Query(
|
||||
default=False, description="Whether this is an intermediate image"
|
||||
image_category: ImageCategory = Query(description="The category of the image"),
|
||||
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
|
||||
show_in_gallery: bool = Query(
|
||||
description="Whether this image should be shown in the gallery"
|
||||
),
|
||||
session_id: Optional[str] = Query(
|
||||
default=None, description="The session ID associated with this upload, if any"
|
||||
@ -63,6 +62,7 @@ async def upload_image(
|
||||
image_category=image_category,
|
||||
session_id=session_id,
|
||||
is_intermediate=is_intermediate,
|
||||
show_in_gallery=show_in_gallery,
|
||||
)
|
||||
|
||||
response.status_code = 201
|
||||
@ -228,24 +228,30 @@ async def get_image_urls(
|
||||
response_model=PaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_images_with_metadata(
|
||||
image_type: ImageType = Query(description="The type of images to list"),
|
||||
image_category: ImageCategory = Query(description="The kind of images to list"),
|
||||
is_intermediate: bool = Query(
|
||||
default=False, description="Whether to list intermediate images"
|
||||
image_type: Optional[ImageType] = Query(
|
||||
default=None, description="The type of images to list"
|
||||
),
|
||||
page: int = Query(default=0, description="The page of image metadata to get"),
|
||||
per_page: int = Query(
|
||||
default=10, description="The number of image metadata per page"
|
||||
image_category: Optional[ImageCategory] = Query(
|
||||
default=None, description="The kind of images to list"
|
||||
),
|
||||
is_intermediate: Optional[bool] = Query(
|
||||
default=None, description="Whether to list intermediate images"
|
||||
),
|
||||
show_in_gallery: Optional[bool] = Query(
|
||||
default=None, description="Whether to list images that show in the gallery"
|
||||
),
|
||||
page: int = Query(default=0, description="The page of images to get"),
|
||||
per_page: int = Query(default=10, description="The number of images per page"),
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
"""Gets a list of images with metadata"""
|
||||
"""Gets a list of images"""
|
||||
|
||||
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
||||
page,
|
||||
per_page,
|
||||
image_type,
|
||||
image_category,
|
||||
is_intermediate,
|
||||
page,
|
||||
per_page,
|
||||
show_in_gallery,
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
|
@ -63,11 +63,12 @@ class ImageRecordStorageBase(ABC):
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
is_intermediate: bool = False,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
image_type: Optional[ImageType] = None,
|
||||
image_category: Optional[ImageCategory] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
show_in_gallery: Optional[bool] = None,
|
||||
) -> PaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
@ -91,6 +92,7 @@ class ImageRecordStorageBase(ABC):
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
is_intermediate: bool = False,
|
||||
show_in_gallery: bool = True,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
@ -137,6 +139,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
session_id TEXT,
|
||||
node_id TEXT,
|
||||
metadata TEXT,
|
||||
show_in_gallery BOOLEAN DEFAULT TRUE,
|
||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
-- Updated via trigger
|
||||
@ -224,7 +227,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
""",
|
||||
(changes.image_category, image_name),
|
||||
)
|
||||
|
||||
|
||||
# Change the session associated with the image
|
||||
if changes.session_id is not None:
|
||||
self._cursor.execute(
|
||||
@ -244,36 +247,72 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
is_intermediate: bool = False,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
image_type: Optional[ImageType] = None,
|
||||
image_category: Optional[ImageCategory] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
show_in_gallery: Optional[bool] = None,
|
||||
) -> PaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT * FROM images
|
||||
WHERE image_type = ? AND image_category = ? AND is_intermediate = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(image_type.value, image_category.value, is_intermediate, per_page, page * per_page),
|
||||
)
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1
|
||||
"""
|
||||
|
||||
images_query = """--sql
|
||||
SELECT * FROM images WHERE 1=1
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params = []
|
||||
|
||||
if image_type is not None:
|
||||
query_conditions += """--sql
|
||||
AND image_type = ?
|
||||
"""
|
||||
query_params.append(image_type.value)
|
||||
|
||||
if image_category is not None:
|
||||
query_conditions += """--sql
|
||||
AND image_category = ?
|
||||
"""
|
||||
query_params.append(image_category.value)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND is_intermediate = ?
|
||||
"""
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if show_in_gallery is not None:
|
||||
query_conditions += """--sql
|
||||
AND show_in_gallery = ?
|
||||
"""
|
||||
query_params.append(show_in_gallery)
|
||||
|
||||
query_pagination = """--sql
|
||||
ORDER BY created_at DESC LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
images_params = query_params.copy()
|
||||
images_params.append(per_page)
|
||||
images_params.append(page * per_page)
|
||||
|
||||
self._cursor.execute(images_query, images_params)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
|
||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*) FROM images
|
||||
WHERE image_type = ? AND image_category = ? AND is_intermediate = ?
|
||||
""",
|
||||
(image_type.value, image_category.value, is_intermediate),
|
||||
)
|
||||
self._cursor.execute(count_query, count_params)
|
||||
|
||||
count = self._cursor.fetchone()[0]
|
||||
except sqlite3.Error as e:
|
||||
@ -316,6 +355,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
is_intermediate: bool = False,
|
||||
show_in_gallery: bool = True,
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = (
|
||||
@ -333,9 +373,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate
|
||||
is_intermediate,
|
||||
show_in_gallery
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
@ -347,6 +388,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
session_id,
|
||||
metadata_json,
|
||||
is_intermediate,
|
||||
show_in_gallery,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
@ -1,8 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from os import name
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
import uuid
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.models.image import (
|
||||
@ -51,6 +49,7 @@ class ImageServiceABC(ABC):
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
intermediate: bool = False,
|
||||
show_in_gallery: bool = True,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
@ -100,10 +99,12 @@ class ImageServiceABC(ABC):
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
image_type: Optional[ImageType] = None,
|
||||
image_category: Optional[ImageCategory] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
show_in_gallery: Optional[bool] = None,
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
@ -175,6 +176,7 @@ class ImageService(ImageServiceABC):
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_intermediate: bool = False,
|
||||
show_in_gallery: bool = True,
|
||||
) -> ImageDTO:
|
||||
if image_type not in ImageType:
|
||||
raise InvalidImageTypeException
|
||||
@ -199,6 +201,7 @@ class ImageService(ImageServiceABC):
|
||||
height=height,
|
||||
# Meta fields
|
||||
is_intermediate=is_intermediate,
|
||||
show_in_gallery=show_in_gallery,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
@ -233,6 +236,7 @@ class ImageService(ImageServiceABC):
|
||||
updated_at=created_at, # this is always the same as the created_at at this time
|
||||
deleted_at=None,
|
||||
is_intermediate=is_intermediate,
|
||||
show_in_gallery=show_in_gallery,
|
||||
# Extra non-nullable fields for DTO
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
@ -328,28 +332,30 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
is_intermediate: bool = False,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
image_type: Optional[ImageType] = None,
|
||||
image_category: Optional[ImageCategory] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
show_in_gallery: Optional[bool] = None,
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self._services.records.get_many(
|
||||
page,
|
||||
per_page,
|
||||
image_type,
|
||||
image_category,
|
||||
is_intermediate,
|
||||
page,
|
||||
per_page,
|
||||
show_in_gallery,
|
||||
)
|
||||
|
||||
image_dtos = list(
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self._services.urls.get_image_url(image_type, r.image_name),
|
||||
self._services.urls.get_image_url(r.image_type, r.image_name),
|
||||
self._services.urls.get_image_url(
|
||||
image_type, r.image_name, True
|
||||
r.image_type, r.image_name, True
|
||||
),
|
||||
),
|
||||
results.items,
|
||||
|
@ -33,6 +33,8 @@ class ImageRecord(BaseModel):
|
||||
"""The deleted timestamp of the image."""
|
||||
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
||||
"""Whether this is an intermediate image."""
|
||||
show_in_gallery: bool = Field(description="Whether this image should be shown in the gallery.")
|
||||
"""Whether this image should be shown in the gallery."""
|
||||
session_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The session ID that generated this image, if it is a generated image.",
|
||||
@ -117,6 +119,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||
is_intermediate = image_dict.get("is_intermediate", False)
|
||||
show_in_gallery = image_dict.get("show_in_gallery", True)
|
||||
|
||||
raw_metadata = image_dict.get("metadata")
|
||||
|
||||
@ -138,4 +141,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
is_intermediate=is_intermediate,
|
||||
show_in_gallery=show_in_gallery,
|
||||
)
|
||||
|
@ -60,13 +60,13 @@ import {
|
||||
addSessionCanceledRejectedListener,
|
||||
} from './listeners/sessionCanceled';
|
||||
import {
|
||||
addReceivedResultImagesPageFulfilledListener,
|
||||
addReceivedResultImagesPageRejectedListener,
|
||||
} from './listeners/receivedResultImagesPage';
|
||||
addReceivedGalleryImagesFulfilledListener,
|
||||
addReceivedGalleryImagesRejectedListener,
|
||||
} from './listeners/receivedGalleryImages';
|
||||
import {
|
||||
addReceivedUploadImagesPageFulfilledListener,
|
||||
addReceivedUploadImagesPageRejectedListener,
|
||||
} from './listeners/receivedUploadImagesPage';
|
||||
} from './listeners/receivedUploadImages';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -146,7 +146,7 @@ addSessionCanceledFulfilledListener();
|
||||
addSessionCanceledRejectedListener();
|
||||
|
||||
// Gallery pages
|
||||
addReceivedResultImagesPageFulfilledListener();
|
||||
addReceivedResultImagesPageRejectedListener();
|
||||
addReceivedGalleryImagesFulfilledListener();
|
||||
addReceivedGalleryImagesRejectedListener();
|
||||
addReceivedUploadImagesPageFulfilledListener();
|
||||
addReceivedUploadImagesPageRejectedListener();
|
||||
|
@ -55,6 +55,9 @@ export const addCanvasMergedListener = () => {
|
||||
formData: {
|
||||
file: new File([blob], filename, { type: 'image/png' }),
|
||||
},
|
||||
imageCategory: 'general',
|
||||
isIntermediate: true,
|
||||
showInGallery: false,
|
||||
})
|
||||
);
|
||||
|
||||
|
@ -4,13 +4,15 @@ import { log } from 'app/logging/useLogger';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { resultUpserted } from 'features/gallery/store/resultsSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
|
||||
|
||||
export const addCanvasSavedToGalleryListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: canvasSavedToGallery,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
effect: async (action, { dispatch, getState, take }) => {
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state);
|
||||
@ -27,13 +29,26 @@ export const addCanvasSavedToGalleryListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const filename = `mergedCanvas_${uuidv4()}.png`;
|
||||
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
formData: {
|
||||
file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }),
|
||||
file: new File([blob], filename, { type: 'image/png' }),
|
||||
},
|
||||
imageCategory: 'general',
|
||||
isIntermediate: false,
|
||||
showInGallery: true,
|
||||
})
|
||||
);
|
||||
|
||||
const [{ payload: uploadedImageDTO }] = await take(
|
||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.meta.arg.formData.file.name === filename
|
||||
);
|
||||
|
||||
dispatch(resultUpserted(uploadedImageDTO));
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -3,10 +3,7 @@ import { uploadUpserted } from 'features/gallery/store/uploadsSlice';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { resultUpserted } from 'features/gallery/store/resultsSlice';
|
||||
import { isResultsImageDTO, isUploadsImageDTO } from 'services/types/guards';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
@ -24,7 +21,7 @@ export const addImageUploadedFulfilledListener = () => {
|
||||
const state = getState();
|
||||
|
||||
// Handle uploads
|
||||
if (isUploadsImageDTO(image)) {
|
||||
if (!image.show_in_gallery && image.image_type === 'uploads') {
|
||||
dispatch(uploadUpserted(image));
|
||||
|
||||
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
|
||||
@ -32,19 +29,11 @@ export const addImageUploadedFulfilledListener = () => {
|
||||
if (state.gallery.shouldAutoSwitchToNewImages) {
|
||||
dispatch(imageSelected(image));
|
||||
}
|
||||
|
||||
if (action.meta.arg.activeTabName === 'img2img') {
|
||||
dispatch(initialImageSelected(image));
|
||||
}
|
||||
|
||||
if (action.meta.arg.activeTabName === 'unifiedCanvas') {
|
||||
dispatch(setInitialCanvasImage(image));
|
||||
}
|
||||
}
|
||||
|
||||
// Handle results
|
||||
// TODO: Can this ever happen? I don't think so...
|
||||
if (isResultsImageDTO(image)) {
|
||||
if (image.show_in_gallery) {
|
||||
dispatch(resultUpserted(image));
|
||||
}
|
||||
},
|
||||
|
@ -1,31 +1,31 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { receivedResultImagesPage } from 'services/thunks/gallery';
|
||||
import { receivedGalleryImages } from 'services/thunks/gallery';
|
||||
import { serializeError } from 'serialize-error';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'gallery' });
|
||||
|
||||
export const addReceivedResultImagesPageFulfilledListener = () => {
|
||||
export const addReceivedGalleryImagesFulfilledListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: receivedResultImagesPage.fulfilled,
|
||||
actionCreator: receivedGalleryImages.fulfilled,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const page = action.payload;
|
||||
moduleLog.debug(
|
||||
{ data: { page } },
|
||||
`Received ${page.items.length} results`
|
||||
`Received ${page.items.length} gallery images`
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addReceivedResultImagesPageRejectedListener = () => {
|
||||
export const addReceivedGalleryImagesRejectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: receivedResultImagesPage.rejected,
|
||||
actionCreator: receivedGalleryImages.rejected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
if (action.payload) {
|
||||
moduleLog.debug(
|
||||
{ data: { error: serializeError(action.payload.error) } },
|
||||
'Problem receiving results'
|
||||
'Problem receiving gallery images'
|
||||
);
|
||||
}
|
||||
},
|
@ -1,18 +1,18 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { receivedUploadImagesPage } from 'services/thunks/gallery';
|
||||
import { receivedUploadImages } from 'services/thunks/gallery';
|
||||
import { serializeError } from 'serialize-error';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'gallery' });
|
||||
|
||||
export const addReceivedUploadImagesPageFulfilledListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: receivedUploadImagesPage.fulfilled,
|
||||
actionCreator: receivedUploadImages.fulfilled,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const page = action.payload;
|
||||
moduleLog.debug(
|
||||
{ data: { page } },
|
||||
`Received ${page.items.length} uploads`
|
||||
`Received ${page.items.length} uploaded images`
|
||||
);
|
||||
},
|
||||
});
|
||||
@ -20,12 +20,12 @@ export const addReceivedUploadImagesPageFulfilledListener = () => {
|
||||
|
||||
export const addReceivedUploadImagesPageRejectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: receivedUploadImagesPage.rejected,
|
||||
actionCreator: receivedUploadImages.rejected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
if (action.payload) {
|
||||
moduleLog.debug(
|
||||
{ data: { error: serializeError(action.payload.error) } },
|
||||
'Problem receiving uploads'
|
||||
'Problem receiving uploaded images'
|
||||
);
|
||||
}
|
||||
},
|
@ -2,8 +2,8 @@ import { startAppListening } from '../..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { socketConnected } from 'services/events/actions';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
receivedUploadImagesPage,
|
||||
receivedGalleryImages,
|
||||
receivedUploadImages,
|
||||
} from 'services/thunks/gallery';
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||
@ -24,11 +24,11 @@ export const addSocketConnectedListener = () => {
|
||||
|
||||
// These thunks need to be dispatch in middleware; cannot handle in a reducer
|
||||
if (!results.ids.length) {
|
||||
dispatch(receivedResultImagesPage());
|
||||
dispatch(receivedGalleryImages());
|
||||
}
|
||||
|
||||
if (!uploads.ids.length) {
|
||||
dispatch(receivedUploadImagesPage());
|
||||
dispatch(receivedUploadImages());
|
||||
}
|
||||
|
||||
if (!models.ids.length) {
|
||||
|
@ -101,7 +101,9 @@ export const addUserInvokedCanvasListener = () => {
|
||||
formData: {
|
||||
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
|
||||
},
|
||||
imageCategory: 'general',
|
||||
isIntermediate: true,
|
||||
showInGallery: false,
|
||||
})
|
||||
);
|
||||
|
||||
@ -127,7 +129,9 @@ export const addUserInvokedCanvasListener = () => {
|
||||
formData: {
|
||||
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
|
||||
},
|
||||
imageCategory: 'mask',
|
||||
isIntermediate: true,
|
||||
showInGallery: false,
|
||||
})
|
||||
);
|
||||
|
||||
|
@ -4,7 +4,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
||||
type ImageUploadOverlayProps = {
|
||||
isDragAccept: boolean;
|
||||
isDragReject: boolean;
|
||||
overlaySecondaryText: string;
|
||||
setIsHandlingUpload: (isHandlingUpload: boolean) => void;
|
||||
};
|
||||
|
||||
@ -12,7 +11,6 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
|
||||
const {
|
||||
isDragAccept,
|
||||
isDragReject: _isDragAccept,
|
||||
overlaySecondaryText,
|
||||
setIsHandlingUpload,
|
||||
} = props;
|
||||
|
||||
@ -48,7 +46,7 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
|
||||
}}
|
||||
>
|
||||
{isDragAccept ? (
|
||||
<Heading size="lg">Upload Image{overlaySecondaryText}</Heading>
|
||||
<Heading size="lg">Drop to Upload</Heading>
|
||||
) : (
|
||||
<>
|
||||
<Heading size="lg">Invalid Upload</Heading>
|
||||
|
@ -69,11 +69,13 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
formData: { file },
|
||||
activeTabName,
|
||||
imageCategory: 'general',
|
||||
isIntermediate: false,
|
||||
showInGallery: false,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, activeTabName]
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onDrop = useCallback(
|
||||
@ -144,14 +146,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
};
|
||||
}, [inputRef, open, setOpenUploaderFunction]);
|
||||
|
||||
const overlaySecondaryText = useMemo(() => {
|
||||
if (['img2img', 'unifiedCanvas'].includes(activeTabName)) {
|
||||
return ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`;
|
||||
}
|
||||
|
||||
return '';
|
||||
}, [t, activeTabName]);
|
||||
|
||||
return (
|
||||
<Box
|
||||
{...getRootProps({ style: {} })}
|
||||
@ -166,7 +160,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
<ImageUploadOverlay
|
||||
isDragAccept={isDragAccept}
|
||||
isDragReject={isDragReject}
|
||||
overlaySecondaryText={overlaySecondaryText}
|
||||
setIsHandlingUpload={setIsHandlingUpload}
|
||||
/>
|
||||
)}
|
||||
|
@ -43,8 +43,8 @@ import HoverableImage from './HoverableImage';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import { resultsAdapter } from '../store/resultsSlice';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
receivedUploadImagesPage,
|
||||
receivedGalleryImages,
|
||||
receivedUploadImages,
|
||||
} from 'services/thunks/gallery';
|
||||
import { uploadsAdapter } from '../store/uploadsSlice';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
@ -151,11 +151,11 @@ const ImageGalleryContent = () => {
|
||||
|
||||
const handleClickLoadMore = () => {
|
||||
if (currentCategory === 'results') {
|
||||
dispatch(receivedResultImagesPage());
|
||||
dispatch(receivedGalleryImages());
|
||||
}
|
||||
|
||||
if (currentCategory === 'uploads') {
|
||||
dispatch(receivedUploadImagesPage());
|
||||
dispatch(receivedUploadImages());
|
||||
}
|
||||
};
|
||||
|
||||
@ -211,9 +211,9 @@ const ImageGalleryContent = () => {
|
||||
|
||||
const handleEndReached = useCallback(() => {
|
||||
if (currentCategory === 'results') {
|
||||
dispatch(receivedResultImagesPage());
|
||||
dispatch(receivedGalleryImages());
|
||||
} else if (currentCategory === 'uploads') {
|
||||
dispatch(receivedUploadImagesPage());
|
||||
dispatch(receivedUploadImages());
|
||||
}
|
||||
}, [dispatch, currentCategory]);
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
receivedUploadImagesPage,
|
||||
receivedGalleryImages,
|
||||
receivedUploadImages,
|
||||
} from '../../../services/thunks/gallery';
|
||||
import { ImageDTO } from 'services/api';
|
||||
|
||||
@ -60,7 +60,7 @@ export const gallerySlice = createSlice({
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
||||
builder.addCase(receivedGalleryImages.fulfilled, (state, action) => {
|
||||
// rehydrate selectedImage URL when results list comes in
|
||||
// solves case when outdated URL is in local storage
|
||||
const selectedImage = state.selectedImage;
|
||||
@ -76,7 +76,7 @@ export const gallerySlice = createSlice({
|
||||
}
|
||||
}
|
||||
});
|
||||
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
||||
builder.addCase(receivedUploadImages.fulfilled, (state, action) => {
|
||||
// rehydrate selectedImage URL when results list comes in
|
||||
// solves case when outdated URL is in local storage
|
||||
const selectedImage = state.selectedImage;
|
||||
|
@ -5,7 +5,7 @@ import {
|
||||
} from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
receivedGalleryImages,
|
||||
IMAGES_PER_PAGE,
|
||||
} from 'services/thunks/gallery';
|
||||
import { ImageDTO } from 'services/api';
|
||||
@ -15,7 +15,7 @@ export type ResultsImageDTO = Omit<ImageDTO, 'image_type'> & {
|
||||
image_type: 'results';
|
||||
};
|
||||
|
||||
export const resultsAdapter = createEntityAdapter<ResultsImageDTO>({
|
||||
export const resultsAdapter = createEntityAdapter<ImageDTO>({
|
||||
selectId: (image) => image.image_name,
|
||||
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
|
||||
});
|
||||
@ -43,7 +43,7 @@ const resultsSlice = createSlice({
|
||||
name: 'results',
|
||||
initialState: initialResultsState,
|
||||
reducers: {
|
||||
resultUpserted: (state, action: PayloadAction<ResultsImageDTO>) => {
|
||||
resultUpserted: (state, action: PayloadAction<ImageDTO>) => {
|
||||
resultsAdapter.upsertOne(state, action.payload);
|
||||
state.upsertedImageCount += 1;
|
||||
},
|
||||
@ -52,18 +52,18 @@ const resultsSlice = createSlice({
|
||||
/**
|
||||
* Received Result Images Page - PENDING
|
||||
*/
|
||||
builder.addCase(receivedResultImagesPage.pending, (state) => {
|
||||
builder.addCase(receivedGalleryImages.pending, (state) => {
|
||||
state.isLoading = true;
|
||||
});
|
||||
|
||||
/**
|
||||
* Received Result Images Page - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
||||
builder.addCase(receivedGalleryImages.fulfilled, (state, action) => {
|
||||
const { page, pages } = action.payload;
|
||||
|
||||
// We know these will all be of the results type, but it's not represented in the API types
|
||||
const items = action.payload.items as ResultsImageDTO[];
|
||||
const items = action.payload.items;
|
||||
|
||||
resultsAdapter.setMany(state, items);
|
||||
|
||||
|
@ -5,10 +5,7 @@ import {
|
||||
} from '@reduxjs/toolkit';
|
||||
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
receivedUploadImagesPage,
|
||||
IMAGES_PER_PAGE,
|
||||
} from 'services/thunks/gallery';
|
||||
import { receivedUploadImages, IMAGES_PER_PAGE } from 'services/thunks/gallery';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { dateComparator } from 'common/util/dateComparator';
|
||||
|
||||
@ -16,7 +13,7 @@ export type UploadsImageDTO = Omit<ImageDTO, 'image_type'> & {
|
||||
image_type: 'uploads';
|
||||
};
|
||||
|
||||
export const uploadsAdapter = createEntityAdapter<UploadsImageDTO>({
|
||||
export const uploadsAdapter = createEntityAdapter<ImageDTO>({
|
||||
selectId: (image) => image.image_name,
|
||||
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
|
||||
});
|
||||
@ -44,7 +41,7 @@ const uploadsSlice = createSlice({
|
||||
name: 'uploads',
|
||||
initialState: initialUploadsState,
|
||||
reducers: {
|
||||
uploadUpserted: (state, action: PayloadAction<UploadsImageDTO>) => {
|
||||
uploadUpserted: (state, action: PayloadAction<ImageDTO>) => {
|
||||
uploadsAdapter.upsertOne(state, action.payload);
|
||||
state.upsertedImageCount += 1;
|
||||
},
|
||||
@ -53,18 +50,18 @@ const uploadsSlice = createSlice({
|
||||
/**
|
||||
* Received Upload Images Page - PENDING
|
||||
*/
|
||||
builder.addCase(receivedUploadImagesPage.pending, (state) => {
|
||||
builder.addCase(receivedUploadImages.pending, (state) => {
|
||||
state.isLoading = true;
|
||||
});
|
||||
|
||||
/**
|
||||
* Received Upload Images Page - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
||||
builder.addCase(receivedUploadImages.fulfilled, (state, action) => {
|
||||
const { page, pages } = action.payload;
|
||||
|
||||
// We know these will all be of the uploads type, but it's not represented in the API types
|
||||
const items = action.payload.items as UploadsImageDTO[];
|
||||
const items = action.payload.items;
|
||||
|
||||
uploadsAdapter.setMany(state, items);
|
||||
|
||||
|
@ -54,6 +54,10 @@ export type ImageDTO = {
|
||||
* Whether this is an intermediate image.
|
||||
*/
|
||||
is_intermediate: boolean;
|
||||
/**
|
||||
* Whether this image should be shown in the gallery.
|
||||
*/
|
||||
show_in_gallery: boolean;
|
||||
/**
|
||||
* The session ID that generated this image, if it is a generated image.
|
||||
*/
|
||||
|
@ -17,35 +17,40 @@ export class ImagesService {
|
||||
|
||||
/**
|
||||
* List Images With Metadata
|
||||
* Gets a list of images with metadata
|
||||
* Gets a list of images
|
||||
* @returns PaginatedResults_ImageDTO_ Successful Response
|
||||
* @throws ApiError
|
||||
*/
|
||||
public static listImagesWithMetadata({
|
||||
imageType,
|
||||
imageCategory,
|
||||
isIntermediate = false,
|
||||
isIntermediate,
|
||||
showInGallery,
|
||||
page,
|
||||
perPage = 10,
|
||||
}: {
|
||||
/**
|
||||
* The type of images to list
|
||||
*/
|
||||
imageType: ImageType,
|
||||
imageType?: ImageType,
|
||||
/**
|
||||
* The kind of images to list
|
||||
*/
|
||||
imageCategory: ImageCategory,
|
||||
imageCategory?: ImageCategory,
|
||||
/**
|
||||
* The kind of images to list
|
||||
* Whether to list intermediate images
|
||||
*/
|
||||
isIntermediate?: boolean,
|
||||
/**
|
||||
* The page of image metadata to get
|
||||
* Whether to list images that show in the gallery
|
||||
*/
|
||||
showInGallery?: boolean,
|
||||
/**
|
||||
* The page of images to get
|
||||
*/
|
||||
page?: number,
|
||||
/**
|
||||
* The number of image metadata per page
|
||||
* The number of images per page
|
||||
*/
|
||||
perPage?: number,
|
||||
}): CancelablePromise<PaginatedResults_ImageDTO_> {
|
||||
@ -56,6 +61,7 @@ export class ImagesService {
|
||||
'image_type': imageType,
|
||||
'image_category': imageCategory,
|
||||
'is_intermediate': isIntermediate,
|
||||
'show_in_gallery': showInGallery,
|
||||
'page': page,
|
||||
'per_page': perPage,
|
||||
},
|
||||
@ -72,20 +78,25 @@ export class ImagesService {
|
||||
* @throws ApiError
|
||||
*/
|
||||
public static uploadImage({
|
||||
formData,
|
||||
imageCategory,
|
||||
isIntermediate = false,
|
||||
isIntermediate,
|
||||
showInGallery,
|
||||
formData,
|
||||
sessionId,
|
||||
}: {
|
||||
formData: Body_upload_image,
|
||||
/**
|
||||
* The category of the image
|
||||
*/
|
||||
imageCategory?: ImageCategory,
|
||||
imageCategory: ImageCategory,
|
||||
/**
|
||||
* Whether this is an intermediate image
|
||||
*/
|
||||
isIntermediate?: boolean,
|
||||
isIntermediate: boolean,
|
||||
/**
|
||||
* Whether this image should be shown in the gallery
|
||||
*/
|
||||
showInGallery: boolean,
|
||||
formData: Body_upload_image,
|
||||
/**
|
||||
* The session ID associated with this upload, if any
|
||||
*/
|
||||
@ -97,6 +108,7 @@ export class ImagesService {
|
||||
query: {
|
||||
'image_category': imageCategory,
|
||||
'is_intermediate': isIntermediate,
|
||||
'show_in_gallery': showInGallery,
|
||||
'session_id': sessionId,
|
||||
},
|
||||
formData: formData,
|
||||
|
@ -9,7 +9,7 @@ type ReceivedResultImagesPageThunkConfig = {
|
||||
};
|
||||
};
|
||||
|
||||
export const receivedResultImagesPage = createAppAsyncThunk<
|
||||
export const receivedGalleryImages = createAppAsyncThunk<
|
||||
PaginatedResults_ImageDTO_,
|
||||
void,
|
||||
ReceivedResultImagesPageThunkConfig
|
||||
@ -23,9 +23,8 @@ export const receivedResultImagesPage = createAppAsyncThunk<
|
||||
const pageOffset = Math.floor(upsertedImageCount / IMAGES_PER_PAGE);
|
||||
|
||||
const response = await ImagesService.listImagesWithMetadata({
|
||||
imageType: 'results',
|
||||
imageCategory: 'general',
|
||||
isIntermediate: false,
|
||||
showInGallery: true,
|
||||
page: nextPage + pageOffset,
|
||||
perPage: IMAGES_PER_PAGE,
|
||||
});
|
||||
@ -40,7 +39,7 @@ type ReceivedUploadImagesPageThunkConfig = {
|
||||
};
|
||||
};
|
||||
|
||||
export const receivedUploadImagesPage = createAppAsyncThunk<
|
||||
export const receivedUploadImages = createAppAsyncThunk<
|
||||
PaginatedResults_ImageDTO_,
|
||||
void,
|
||||
ReceivedUploadImagesPageThunkConfig
|
||||
@ -55,8 +54,8 @@ export const receivedUploadImagesPage = createAppAsyncThunk<
|
||||
|
||||
const response = await ImagesService.listImagesWithMetadata({
|
||||
imageType: 'uploads',
|
||||
imageCategory: 'general',
|
||||
isIntermediate: false,
|
||||
showInGallery: false,
|
||||
page: nextPage + pageOffset,
|
||||
perPage: IMAGES_PER_PAGE,
|
||||
});
|
||||
|
@ -32,11 +32,7 @@ export const imageMetadataReceived = createAppAsyncThunk(
|
||||
}
|
||||
);
|
||||
|
||||
type ImageUploadedArg = Parameters<(typeof ImagesService)['uploadImage']>[0] & {
|
||||
// extra arg to determine post-upload actions - we check for this when the image is uploaded
|
||||
// to determine if we should set the init image
|
||||
activeTabName?: InvokeTabName;
|
||||
};
|
||||
type ImageUploadedArg = Parameters<(typeof ImagesService)['uploadImage']>[0];
|
||||
|
||||
/**
|
||||
* `ImagesService.uploadImage()` thunk
|
||||
@ -45,8 +41,7 @@ export const imageUploaded = createAppAsyncThunk(
|
||||
'api/imageUploaded',
|
||||
async (arg: ImageUploadedArg) => {
|
||||
// strip out `activeTabName` from arg - the route does not need it
|
||||
const { activeTabName, ...rest } = arg;
|
||||
const response = await ImagesService.uploadImage(rest);
|
||||
const response = await ImagesService.uploadImage(arg);
|
||||
return response;
|
||||
}
|
||||
);
|
||||
|
Loading…
Reference in New Issue
Block a user