feat(nodes, ui): wip image types

This commit is contained in:
psychedelicious 2023-05-27 18:32:16 +10:00 committed by Kent Keirsey
parent bdab73701f
commit 9317b42e5f
22 changed files with 218 additions and 151 deletions

View File

@ -34,11 +34,10 @@ async def upload_image(
file: UploadFile,
request: Request,
response: Response,
image_category: ImageCategory = Query(
default=ImageCategory.GENERAL, description="The category of the image"
),
is_intermediate: bool = Query(
default=False, description="Whether this is an intermediate image"
image_category: ImageCategory = Query(description="The category of the image"),
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
show_in_gallery: bool = Query(
description="Whether this image should be shown in the gallery"
),
session_id: Optional[str] = Query(
default=None, description="The session ID associated with this upload, if any"
@ -63,6 +62,7 @@ async def upload_image(
image_category=image_category,
session_id=session_id,
is_intermediate=is_intermediate,
show_in_gallery=show_in_gallery,
)
response.status_code = 201
@ -228,24 +228,30 @@ async def get_image_urls(
response_model=PaginatedResults[ImageDTO],
)
async def list_images_with_metadata(
image_type: ImageType = Query(description="The type of images to list"),
image_category: ImageCategory = Query(description="The kind of images to list"),
is_intermediate: bool = Query(
default=False, description="Whether to list intermediate images"
image_type: Optional[ImageType] = Query(
default=None, description="The type of images to list"
),
page: int = Query(default=0, description="The page of image metadata to get"),
per_page: int = Query(
default=10, description="The number of image metadata per page"
image_category: Optional[ImageCategory] = Query(
default=None, description="The kind of images to list"
),
is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images"
),
show_in_gallery: Optional[bool] = Query(
default=None, description="Whether to list images that show in the gallery"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageDTO]:
"""Gets a list of images with metadata"""
"""Gets a list of images"""
image_dtos = ApiDependencies.invoker.services.images.get_many(
page,
per_page,
image_type,
image_category,
is_intermediate,
page,
per_page,
show_in_gallery,
)
return image_dtos

View File

@ -63,11 +63,12 @@ class ImageRecordStorageBase(ABC):
@abstractmethod
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
is_intermediate: bool = False,
page: int = 0,
per_page: int = 10,
image_type: Optional[ImageType] = None,
image_category: Optional[ImageCategory] = None,
is_intermediate: Optional[bool] = None,
show_in_gallery: Optional[bool] = None,
) -> PaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass
@ -91,6 +92,7 @@ class ImageRecordStorageBase(ABC):
node_id: Optional[str],
metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
show_in_gallery: bool = True,
) -> datetime:
"""Saves an image record."""
pass
@ -137,6 +139,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id TEXT,
node_id TEXT,
metadata TEXT,
show_in_gallery BOOLEAN DEFAULT TRUE,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
-- Updated via trigger
@ -224,7 +227,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""",
(changes.image_category, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
self._cursor.execute(
@ -244,36 +247,72 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
is_intermediate: bool = False,
page: int = 0,
per_page: int = 10,
image_type: Optional[ImageType] = None,
image_category: Optional[ImageCategory] = None,
is_intermediate: Optional[bool] = None,
show_in_gallery: Optional[bool] = None,
) -> PaginatedResults[ImageRecord]:
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT * FROM images
WHERE image_type = ? AND image_category = ? AND is_intermediate = ?
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
""",
(image_type.value, image_category.value, is_intermediate, per_page, page * per_page),
)
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*) FROM images WHERE 1=1
"""
images_query = """--sql
SELECT * FROM images WHERE 1=1
"""
query_conditions = ""
query_params = []
if image_type is not None:
query_conditions += """--sql
AND image_type = ?
"""
query_params.append(image_type.value)
if image_category is not None:
query_conditions += """--sql
AND image_category = ?
"""
query_params.append(image_category.value)
if is_intermediate is not None:
query_conditions += """--sql
AND is_intermediate = ?
"""
query_params.append(is_intermediate)
if show_in_gallery is not None:
query_conditions += """--sql
AND show_in_gallery = ?
"""
query_params.append(show_in_gallery)
query_pagination = """--sql
ORDER BY created_at DESC LIMIT ? OFFSET ?
"""
count_query += query_conditions + ";"
count_params = query_params.copy()
images_query += query_conditions + query_pagination + ";"
images_params = query_params.copy()
images_params.append(per_page)
images_params.append(page * per_page)
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
self._cursor.execute(
"""--sql
SELECT count(*) FROM images
WHERE image_type = ? AND image_category = ? AND is_intermediate = ?
""",
(image_type.value, image_category.value, is_intermediate),
)
self._cursor.execute(count_query, count_params)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
@ -316,6 +355,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id: Optional[str],
metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
show_in_gallery: bool = True,
) -> datetime:
try:
metadata_json = (
@ -333,9 +373,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id,
session_id,
metadata,
is_intermediate
is_intermediate,
show_in_gallery
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
@ -347,6 +388,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id,
metadata_json,
is_intermediate,
show_in_gallery,
),
)
self._conn.commit()

View File

@ -1,8 +1,6 @@
from abc import ABC, abstractmethod
from logging import Logger
from os import name
from typing import Optional, TYPE_CHECKING, Union
import uuid
from PIL.Image import Image as PILImageType
from invokeai.app.models.image import (
@ -51,6 +49,7 @@ class ImageServiceABC(ABC):
node_id: Optional[str] = None,
session_id: Optional[str] = None,
intermediate: bool = False,
show_in_gallery: bool = True,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@ -100,10 +99,12 @@ class ImageServiceABC(ABC):
@abstractmethod
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
image_type: Optional[ImageType] = None,
image_category: Optional[ImageCategory] = None,
is_intermediate: Optional[bool] = None,
show_in_gallery: Optional[bool] = None,
) -> PaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass
@ -175,6 +176,7 @@ class ImageService(ImageServiceABC):
node_id: Optional[str] = None,
session_id: Optional[str] = None,
is_intermediate: bool = False,
show_in_gallery: bool = True,
) -> ImageDTO:
if image_type not in ImageType:
raise InvalidImageTypeException
@ -199,6 +201,7 @@ class ImageService(ImageServiceABC):
height=height,
# Meta fields
is_intermediate=is_intermediate,
show_in_gallery=show_in_gallery,
# Nullable fields
node_id=node_id,
session_id=session_id,
@ -233,6 +236,7 @@ class ImageService(ImageServiceABC):
updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None,
is_intermediate=is_intermediate,
show_in_gallery=show_in_gallery,
# Extra non-nullable fields for DTO
image_url=image_url,
thumbnail_url=thumbnail_url,
@ -328,28 +332,30 @@ class ImageService(ImageServiceABC):
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
is_intermediate: bool = False,
page: int = 0,
per_page: int = 10,
image_type: Optional[ImageType] = None,
image_category: Optional[ImageCategory] = None,
is_intermediate: Optional[bool] = None,
show_in_gallery: Optional[bool] = None,
) -> PaginatedResults[ImageDTO]:
try:
results = self._services.records.get_many(
page,
per_page,
image_type,
image_category,
is_intermediate,
page,
per_page,
show_in_gallery,
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(image_type, r.image_name),
self._services.urls.get_image_url(r.image_type, r.image_name),
self._services.urls.get_image_url(
image_type, r.image_name, True
r.image_type, r.image_name, True
),
),
results.items,

View File

@ -33,6 +33,8 @@ class ImageRecord(BaseModel):
"""The deleted timestamp of the image."""
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
"""Whether this is an intermediate image."""
show_in_gallery: bool = Field(description="Whether this image should be shown in the gallery.")
"""Whether this image should be shown in the gallery."""
session_id: Optional[str] = Field(
default=None,
description="The session ID that generated this image, if it is a generated image.",
@ -117,6 +119,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False)
show_in_gallery = image_dict.get("show_in_gallery", True)
raw_metadata = image_dict.get("metadata")
@ -138,4 +141,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
updated_at=updated_at,
deleted_at=deleted_at,
is_intermediate=is_intermediate,
show_in_gallery=show_in_gallery,
)

View File

@ -60,13 +60,13 @@ import {
addSessionCanceledRejectedListener,
} from './listeners/sessionCanceled';
import {
addReceivedResultImagesPageFulfilledListener,
addReceivedResultImagesPageRejectedListener,
} from './listeners/receivedResultImagesPage';
addReceivedGalleryImagesFulfilledListener,
addReceivedGalleryImagesRejectedListener,
} from './listeners/receivedGalleryImages';
import {
addReceivedUploadImagesPageFulfilledListener,
addReceivedUploadImagesPageRejectedListener,
} from './listeners/receivedUploadImagesPage';
} from './listeners/receivedUploadImages';
export const listenerMiddleware = createListenerMiddleware();
@ -146,7 +146,7 @@ addSessionCanceledFulfilledListener();
addSessionCanceledRejectedListener();
// Gallery pages
addReceivedResultImagesPageFulfilledListener();
addReceivedResultImagesPageRejectedListener();
addReceivedGalleryImagesFulfilledListener();
addReceivedGalleryImagesRejectedListener();
addReceivedUploadImagesPageFulfilledListener();
addReceivedUploadImagesPageRejectedListener();

View File

@ -55,6 +55,9 @@ export const addCanvasMergedListener = () => {
formData: {
file: new File([blob], filename, { type: 'image/png' }),
},
imageCategory: 'general',
isIntermediate: true,
showInGallery: false,
})
);

View File

@ -4,13 +4,15 @@ import { log } from 'app/logging/useLogger';
import { imageUploaded } from 'services/thunks/image';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
import { v4 as uuidv4 } from 'uuid';
import { resultUpserted } from 'features/gallery/store/resultsSlice';
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
export const addCanvasSavedToGalleryListener = () => {
startAppListening({
actionCreator: canvasSavedToGallery,
effect: async (action, { dispatch, getState }) => {
effect: async (action, { dispatch, getState, take }) => {
const state = getState();
const blob = await getBaseLayerBlob(state);
@ -27,13 +29,26 @@ export const addCanvasSavedToGalleryListener = () => {
return;
}
const filename = `mergedCanvas_${uuidv4()}.png`;
dispatch(
imageUploaded({
formData: {
file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }),
file: new File([blob], filename, { type: 'image/png' }),
},
imageCategory: 'general',
isIntermediate: false,
showInGallery: true,
})
);
const [{ payload: uploadedImageDTO }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === filename
);
dispatch(resultUpserted(uploadedImageDTO));
},
});
};

View File

@ -3,10 +3,7 @@ import { uploadUpserted } from 'features/gallery/store/uploadsSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageUploaded } from 'services/thunks/image';
import { addToast } from 'features/system/store/systemSlice';
import { initialImageSelected } from 'features/parameters/store/actions';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { resultUpserted } from 'features/gallery/store/resultsSlice';
import { isResultsImageDTO, isUploadsImageDTO } from 'services/types/guards';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ namespace: 'image' });
@ -24,7 +21,7 @@ export const addImageUploadedFulfilledListener = () => {
const state = getState();
// Handle uploads
if (isUploadsImageDTO(image)) {
if (!image.show_in_gallery && image.image_type === 'uploads') {
dispatch(uploadUpserted(image));
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
@ -32,19 +29,11 @@ export const addImageUploadedFulfilledListener = () => {
if (state.gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(image));
}
if (action.meta.arg.activeTabName === 'img2img') {
dispatch(initialImageSelected(image));
}
if (action.meta.arg.activeTabName === 'unifiedCanvas') {
dispatch(setInitialCanvasImage(image));
}
}
// Handle results
// TODO: Can this ever happen? I don't think so...
if (isResultsImageDTO(image)) {
if (image.show_in_gallery) {
dispatch(resultUpserted(image));
}
},

View File

@ -1,31 +1,31 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { receivedResultImagesPage } from 'services/thunks/gallery';
import { receivedGalleryImages } from 'services/thunks/gallery';
import { serializeError } from 'serialize-error';
const moduleLog = log.child({ namespace: 'gallery' });
export const addReceivedResultImagesPageFulfilledListener = () => {
export const addReceivedGalleryImagesFulfilledListener = () => {
startAppListening({
actionCreator: receivedResultImagesPage.fulfilled,
actionCreator: receivedGalleryImages.fulfilled,
effect: (action, { getState, dispatch }) => {
const page = action.payload;
moduleLog.debug(
{ data: { page } },
`Received ${page.items.length} results`
`Received ${page.items.length} gallery images`
);
},
});
};
export const addReceivedResultImagesPageRejectedListener = () => {
export const addReceivedGalleryImagesRejectedListener = () => {
startAppListening({
actionCreator: receivedResultImagesPage.rejected,
actionCreator: receivedGalleryImages.rejected,
effect: (action, { getState, dispatch }) => {
if (action.payload) {
moduleLog.debug(
{ data: { error: serializeError(action.payload.error) } },
'Problem receiving results'
'Problem receiving gallery images'
);
}
},

View File

@ -1,18 +1,18 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { receivedUploadImagesPage } from 'services/thunks/gallery';
import { receivedUploadImages } from 'services/thunks/gallery';
import { serializeError } from 'serialize-error';
const moduleLog = log.child({ namespace: 'gallery' });
export const addReceivedUploadImagesPageFulfilledListener = () => {
startAppListening({
actionCreator: receivedUploadImagesPage.fulfilled,
actionCreator: receivedUploadImages.fulfilled,
effect: (action, { getState, dispatch }) => {
const page = action.payload;
moduleLog.debug(
{ data: { page } },
`Received ${page.items.length} uploads`
`Received ${page.items.length} uploaded images`
);
},
});
@ -20,12 +20,12 @@ export const addReceivedUploadImagesPageFulfilledListener = () => {
export const addReceivedUploadImagesPageRejectedListener = () => {
startAppListening({
actionCreator: receivedUploadImagesPage.rejected,
actionCreator: receivedUploadImages.rejected,
effect: (action, { getState, dispatch }) => {
if (action.payload) {
moduleLog.debug(
{ data: { error: serializeError(action.payload.error) } },
'Problem receiving uploads'
'Problem receiving uploaded images'
);
}
},

View File

@ -2,8 +2,8 @@ import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { socketConnected } from 'services/events/actions';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
receivedGalleryImages,
receivedUploadImages,
} from 'services/thunks/gallery';
import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema';
@ -24,11 +24,11 @@ export const addSocketConnectedListener = () => {
// These thunks need to be dispatch in middleware; cannot handle in a reducer
if (!results.ids.length) {
dispatch(receivedResultImagesPage());
dispatch(receivedGalleryImages());
}
if (!uploads.ids.length) {
dispatch(receivedUploadImagesPage());
dispatch(receivedUploadImages());
}
if (!models.ids.length) {

View File

@ -101,7 +101,9 @@ export const addUserInvokedCanvasListener = () => {
formData: {
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
},
imageCategory: 'general',
isIntermediate: true,
showInGallery: false,
})
);
@ -127,7 +129,9 @@ export const addUserInvokedCanvasListener = () => {
formData: {
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
},
imageCategory: 'mask',
isIntermediate: true,
showInGallery: false,
})
);

View File

@ -4,7 +4,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
type ImageUploadOverlayProps = {
isDragAccept: boolean;
isDragReject: boolean;
overlaySecondaryText: string;
setIsHandlingUpload: (isHandlingUpload: boolean) => void;
};
@ -12,7 +11,6 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
const {
isDragAccept,
isDragReject: _isDragAccept,
overlaySecondaryText,
setIsHandlingUpload,
} = props;
@ -48,7 +46,7 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
}}
>
{isDragAccept ? (
<Heading size="lg">Upload Image{overlaySecondaryText}</Heading>
<Heading size="lg">Drop to Upload</Heading>
) : (
<>
<Heading size="lg">Invalid Upload</Heading>

View File

@ -69,11 +69,13 @@ const ImageUploader = (props: ImageUploaderProps) => {
dispatch(
imageUploaded({
formData: { file },
activeTabName,
imageCategory: 'general',
isIntermediate: false,
showInGallery: false,
})
);
},
[dispatch, activeTabName]
[dispatch]
);
const onDrop = useCallback(
@ -144,14 +146,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
};
}, [inputRef, open, setOpenUploaderFunction]);
const overlaySecondaryText = useMemo(() => {
if (['img2img', 'unifiedCanvas'].includes(activeTabName)) {
return ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`;
}
return '';
}, [t, activeTabName]);
return (
<Box
{...getRootProps({ style: {} })}
@ -166,7 +160,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
<ImageUploadOverlay
isDragAccept={isDragAccept}
isDragReject={isDragReject}
overlaySecondaryText={overlaySecondaryText}
setIsHandlingUpload={setIsHandlingUpload}
/>
)}

View File

@ -43,8 +43,8 @@ import HoverableImage from './HoverableImage';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { resultsAdapter } from '../store/resultsSlice';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
receivedGalleryImages,
receivedUploadImages,
} from 'services/thunks/gallery';
import { uploadsAdapter } from '../store/uploadsSlice';
import { createSelector } from '@reduxjs/toolkit';
@ -151,11 +151,11 @@ const ImageGalleryContent = () => {
const handleClickLoadMore = () => {
if (currentCategory === 'results') {
dispatch(receivedResultImagesPage());
dispatch(receivedGalleryImages());
}
if (currentCategory === 'uploads') {
dispatch(receivedUploadImagesPage());
dispatch(receivedUploadImages());
}
};
@ -211,9 +211,9 @@ const ImageGalleryContent = () => {
const handleEndReached = useCallback(() => {
if (currentCategory === 'results') {
dispatch(receivedResultImagesPage());
dispatch(receivedGalleryImages());
} else if (currentCategory === 'uploads') {
dispatch(receivedUploadImagesPage());
dispatch(receivedUploadImages());
}
}, [dispatch, currentCategory]);

View File

@ -1,8 +1,8 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
receivedGalleryImages,
receivedUploadImages,
} from '../../../services/thunks/gallery';
import { ImageDTO } from 'services/api';
@ -60,7 +60,7 @@ export const gallerySlice = createSlice({
},
},
extraReducers(builder) {
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
builder.addCase(receivedGalleryImages.fulfilled, (state, action) => {
// rehydrate selectedImage URL when results list comes in
// solves case when outdated URL is in local storage
const selectedImage = state.selectedImage;
@ -76,7 +76,7 @@ export const gallerySlice = createSlice({
}
}
});
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
builder.addCase(receivedUploadImages.fulfilled, (state, action) => {
// rehydrate selectedImage URL when results list comes in
// solves case when outdated URL is in local storage
const selectedImage = state.selectedImage;

View File

@ -5,7 +5,7 @@ import {
} from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
receivedResultImagesPage,
receivedGalleryImages,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { ImageDTO } from 'services/api';
@ -15,7 +15,7 @@ export type ResultsImageDTO = Omit<ImageDTO, 'image_type'> & {
image_type: 'results';
};
export const resultsAdapter = createEntityAdapter<ResultsImageDTO>({
export const resultsAdapter = createEntityAdapter<ImageDTO>({
selectId: (image) => image.image_name,
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
});
@ -43,7 +43,7 @@ const resultsSlice = createSlice({
name: 'results',
initialState: initialResultsState,
reducers: {
resultUpserted: (state, action: PayloadAction<ResultsImageDTO>) => {
resultUpserted: (state, action: PayloadAction<ImageDTO>) => {
resultsAdapter.upsertOne(state, action.payload);
state.upsertedImageCount += 1;
},
@ -52,18 +52,18 @@ const resultsSlice = createSlice({
/**
* Received Result Images Page - PENDING
*/
builder.addCase(receivedResultImagesPage.pending, (state) => {
builder.addCase(receivedGalleryImages.pending, (state) => {
state.isLoading = true;
});
/**
* Received Result Images Page - FULFILLED
*/
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
builder.addCase(receivedGalleryImages.fulfilled, (state, action) => {
const { page, pages } = action.payload;
// We know these will all be of the results type, but it's not represented in the API types
const items = action.payload.items as ResultsImageDTO[];
const items = action.payload.items;
resultsAdapter.setMany(state, items);

View File

@ -5,10 +5,7 @@ import {
} from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
receivedUploadImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { receivedUploadImages, IMAGES_PER_PAGE } from 'services/thunks/gallery';
import { ImageDTO } from 'services/api';
import { dateComparator } from 'common/util/dateComparator';
@ -16,7 +13,7 @@ export type UploadsImageDTO = Omit<ImageDTO, 'image_type'> & {
image_type: 'uploads';
};
export const uploadsAdapter = createEntityAdapter<UploadsImageDTO>({
export const uploadsAdapter = createEntityAdapter<ImageDTO>({
selectId: (image) => image.image_name,
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
});
@ -44,7 +41,7 @@ const uploadsSlice = createSlice({
name: 'uploads',
initialState: initialUploadsState,
reducers: {
uploadUpserted: (state, action: PayloadAction<UploadsImageDTO>) => {
uploadUpserted: (state, action: PayloadAction<ImageDTO>) => {
uploadsAdapter.upsertOne(state, action.payload);
state.upsertedImageCount += 1;
},
@ -53,18 +50,18 @@ const uploadsSlice = createSlice({
/**
* Received Upload Images Page - PENDING
*/
builder.addCase(receivedUploadImagesPage.pending, (state) => {
builder.addCase(receivedUploadImages.pending, (state) => {
state.isLoading = true;
});
/**
* Received Upload Images Page - FULFILLED
*/
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
builder.addCase(receivedUploadImages.fulfilled, (state, action) => {
const { page, pages } = action.payload;
// We know these will all be of the uploads type, but it's not represented in the API types
const items = action.payload.items as UploadsImageDTO[];
const items = action.payload.items;
uploadsAdapter.setMany(state, items);

View File

@ -54,6 +54,10 @@ export type ImageDTO = {
* Whether this is an intermediate image.
*/
is_intermediate: boolean;
/**
* Whether this image should be shown in the gallery.
*/
show_in_gallery: boolean;
/**
* The session ID that generated this image, if it is a generated image.
*/

View File

@ -17,35 +17,40 @@ export class ImagesService {
/**
* List Images With Metadata
* Gets a list of images with metadata
* Gets a list of images
* @returns PaginatedResults_ImageDTO_ Successful Response
* @throws ApiError
*/
public static listImagesWithMetadata({
imageType,
imageCategory,
isIntermediate = false,
isIntermediate,
showInGallery,
page,
perPage = 10,
}: {
/**
* The type of images to list
*/
imageType: ImageType,
imageType?: ImageType,
/**
* The kind of images to list
*/
imageCategory: ImageCategory,
imageCategory?: ImageCategory,
/**
* The kind of images to list
* Whether to list intermediate images
*/
isIntermediate?: boolean,
/**
* The page of image metadata to get
* Whether to list images that show in the gallery
*/
showInGallery?: boolean,
/**
* The page of images to get
*/
page?: number,
/**
* The number of image metadata per page
* The number of images per page
*/
perPage?: number,
}): CancelablePromise<PaginatedResults_ImageDTO_> {
@ -56,6 +61,7 @@ export class ImagesService {
'image_type': imageType,
'image_category': imageCategory,
'is_intermediate': isIntermediate,
'show_in_gallery': showInGallery,
'page': page,
'per_page': perPage,
},
@ -72,20 +78,25 @@ export class ImagesService {
* @throws ApiError
*/
public static uploadImage({
formData,
imageCategory,
isIntermediate = false,
isIntermediate,
showInGallery,
formData,
sessionId,
}: {
formData: Body_upload_image,
/**
* The category of the image
*/
imageCategory?: ImageCategory,
imageCategory: ImageCategory,
/**
* Whether this is an intermediate image
*/
isIntermediate?: boolean,
isIntermediate: boolean,
/**
* Whether this image should be shown in the gallery
*/
showInGallery: boolean,
formData: Body_upload_image,
/**
* The session ID associated with this upload, if any
*/
@ -97,6 +108,7 @@ export class ImagesService {
query: {
'image_category': imageCategory,
'is_intermediate': isIntermediate,
'show_in_gallery': showInGallery,
'session_id': sessionId,
},
formData: formData,

View File

@ -9,7 +9,7 @@ type ReceivedResultImagesPageThunkConfig = {
};
};
export const receivedResultImagesPage = createAppAsyncThunk<
export const receivedGalleryImages = createAppAsyncThunk<
PaginatedResults_ImageDTO_,
void,
ReceivedResultImagesPageThunkConfig
@ -23,9 +23,8 @@ export const receivedResultImagesPage = createAppAsyncThunk<
const pageOffset = Math.floor(upsertedImageCount / IMAGES_PER_PAGE);
const response = await ImagesService.listImagesWithMetadata({
imageType: 'results',
imageCategory: 'general',
isIntermediate: false,
showInGallery: true,
page: nextPage + pageOffset,
perPage: IMAGES_PER_PAGE,
});
@ -40,7 +39,7 @@ type ReceivedUploadImagesPageThunkConfig = {
};
};
export const receivedUploadImagesPage = createAppAsyncThunk<
export const receivedUploadImages = createAppAsyncThunk<
PaginatedResults_ImageDTO_,
void,
ReceivedUploadImagesPageThunkConfig
@ -55,8 +54,8 @@ export const receivedUploadImagesPage = createAppAsyncThunk<
const response = await ImagesService.listImagesWithMetadata({
imageType: 'uploads',
imageCategory: 'general',
isIntermediate: false,
showInGallery: false,
page: nextPage + pageOffset,
perPage: IMAGES_PER_PAGE,
});

View File

@ -32,11 +32,7 @@ export const imageMetadataReceived = createAppAsyncThunk(
}
);
type ImageUploadedArg = Parameters<(typeof ImagesService)['uploadImage']>[0] & {
// extra arg to determine post-upload actions - we check for this when the image is uploaded
// to determine if we should set the init image
activeTabName?: InvokeTabName;
};
type ImageUploadedArg = Parameters<(typeof ImagesService)['uploadImage']>[0];
/**
* `ImagesService.uploadImage()` thunk
@ -45,8 +41,7 @@ export const imageUploaded = createAppAsyncThunk(
'api/imageUploaded',
async (arg: ImageUploadedArg) => {
// strip out `activeTabName` from arg - the route does not need it
const { activeTabName, ...rest } = arg;
const response = await ImagesService.uploadImage(rest);
const response = await ImagesService.uploadImage(arg);
return response;
}
);