feat(ui): clean up & comment results slice

This commit is contained in:
psychedelicious 2023-04-04 00:31:17 +10:00
parent 9baa8f7a6a
commit dbf6b1b68a
8 changed files with 147 additions and 136 deletions

View File

@ -33,10 +33,9 @@ import {
STATUS, STATUS,
} from 'services/apiSlice'; } from 'services/apiSlice';
import { emitUnsubscribe } from './actions'; import { emitUnsubscribe } from './actions';
import { getGalleryImages } from 'services/thunks/extra';
import { resultAdded } from 'features/gallery/store/resultsSlice'; import { resultAdded } from 'features/gallery/store/resultsSlice';
import { buildImageUrls } from 'services/util/buildImageUrls'; import { getNextResultsPage } from 'services/thunks/extra';
import { extractTimestampFromResultImageName } from 'services/util/extractTimestampFromResultImageName'; import { prepareResultImage } from 'services/util/prepareResultImage';
/** /**
* Returns an object containing listener callbacks * Returns an object containing listener callbacks
@ -54,7 +53,12 @@ const makeSocketIOListeners = (
try { try {
dispatch(setIsConnected(true)); dispatch(setIsConnected(true));
dispatch(setCurrentStatus(i18n.t('common.statusConnected'))); dispatch(setCurrentStatus(i18n.t('common.statusConnected')));
dispatch(getGalleryImages({ count: 20 }));
// fetch more results, but only if we don't already have results
// maybe we should have a different thunk for `onConnect` vs when you click 'Load More'?
if (!getState().results.ids.length) {
dispatch(getNextResultsPage());
}
} catch (e) { } catch (e) {
console.error(e); console.error(e);
} }
@ -90,15 +94,9 @@ const makeSocketIOListeners = (
try { try {
const sessionId = data.graph_execution_state_id; const sessionId = data.graph_execution_state_id;
if (data.result.type === 'image') { if (data.result.type === 'image') {
const { image_name: imageName } = data.result.image; const resultImage = prepareResultImage(data.result.image);
const { imageUrl, thumbnailUrl } = buildImageUrls(
'results',
imageName
);
const timestamp = extractTimestampFromResultImageName(imageName);
dispatch(resultAdded(resultImage));
// // need to update the type for this or figure out how to get these values // // need to update the type for this or figure out how to get these values
// dispatch( // dispatch(
// addImage({ // addImage({
@ -116,20 +114,10 @@ const makeSocketIOListeners = (
// }) // })
// ); // );
dispatch(
resultAdded({
name: imageName,
url: imageUrl,
thumbnail: thumbnailUrl,
width: 512,
height: 512,
timestamp,
})
);
dispatch( dispatch(
addLogEntry({ addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'), timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Generated: ${imageName}`, message: `Generated: ${data.result.image.image_name}`,
}) })
); );
dispatch(setIsProcessing(false)); dispatch(setIsProcessing(false));

View File

@ -94,6 +94,8 @@ const rootPersistConfig = getPersistConfig({
...galleryBlacklist, ...galleryBlacklist,
...lightboxBlacklist, ...lightboxBlacklist,
...apiBlacklist, ...apiBlacklist,
// for now, never persist the results slice
'results',
], ],
debounce: 300, debounce: 300,
}); });

View File

@ -25,7 +25,8 @@ import HoverableImage from './HoverableImage';
import Scrollable from 'features/ui/components/common/Scrollable'; import Scrollable from 'features/ui/components/common/Scrollable';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { selectResultsAll } from '../store/resultsSlice'; import { selectResultsAll, selectResultsTotal } from '../store/resultsSlice';
import { getNextResultsPage } from 'services/thunks/extra';
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290; const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
@ -49,9 +50,15 @@ const ImageGalleryContent = () => {
} = useAppSelector(imageGallerySelector); } = useAppSelector(imageGallerySelector);
const allResultImages = useAppSelector(selectResultsAll); const allResultImages = useAppSelector(selectResultsAll);
const currentResultsPage = useAppSelector((state) => state.results.page);
const totalResultsPages = useAppSelector((state) => state.results.pages);
const isLoadingResults = useAppSelector((state) => state.results.isLoading);
// const handleClickLoadMore = () => {
// dispatch(requestImages(currentCategory));
// };
const handleClickLoadMore = () => { const handleClickLoadMore = () => {
dispatch(requestImages(currentCategory)); dispatch(getNextResultsPage());
}; };
const handleChangeGalleryImageMinimumWidth = (v: number) => { const handleChangeGalleryImageMinimumWidth = (v: number) => {
@ -222,10 +229,11 @@ const ImageGalleryContent = () => {
</Grid> </Grid>
<IAIButton <IAIButton
onClick={handleClickLoadMore} onClick={handleClickLoadMore}
isDisabled={!areMoreImagesAvailable} isDisabled={currentResultsPage === totalResultsPages - 1}
isLoading={isLoadingResults}
flexShrink={0} flexShrink={0}
> >
{areMoreImagesAvailable {currentResultsPage !== totalResultsPages - 1
? t('gallery.loadMore') ? t('gallery.loadMore')
: t('gallery.allImagesLoaded')} : t('gallery.allImagesLoaded')}
</IAIButton> </IAIButton>

View File

@ -1,27 +1,81 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { ResultImage } from 'app/invokeai';
import * as InvokeAI from 'app/invokeai';
import { RootState } from 'app/store'; import { RootState } from 'app/store';
import { map } from 'lodash';
import { getNextResultsPage } from 'services/thunks/extra';
import { isImageOutput } from 'services/types/guards';
import { prepareResultImage } from 'services/util/prepareResultImage';
const resultsAdapter = createEntityAdapter<InvokeAI.ResultImage>({ // use `createEntityAdapter` to create a slice for results images
// Image IDs are just their filename // https://redux-toolkit.js.org/api/createEntityAdapter#overview
// the "Entity" is InvokeAI.ResultImage, while the "entities" are instances of that type
const resultsAdapter = createEntityAdapter<ResultImage>({
// Provide a callback to get a stable, unique identifier for each entity. This defaults to
// `(item) => item.id`, but for our result images, the `name` is the unique identifier.
selectId: (image) => image.name, selectId: (image) => image.name,
// Keep the "all IDs" array sorted based on result timestamps // Order all images by their time (in descending order)
sortComparer: (a, b) => b.timestamp - a.timestamp, sortComparer: (a, b) => b.timestamp - a.timestamp,
}); });
// This type is intersected with the Entity type to create the shape of the state
type AdditionalResultsState = {
// these are a bit misleading; they refer to sessions, not results, but we don't have a route
// to list all images directly at this time...
page: number; // current page we are on
pages: number; // the total number of pages available
isLoading: boolean; // whether we are loading more images or not, mostly a placeholder
};
const resultsSlice = createSlice({ const resultsSlice = createSlice({
name: 'results', name: 'results',
initialState: resultsAdapter.getInitialState(), initialState: resultsAdapter.getInitialState<AdditionalResultsState>({
// provide the additional initial state
page: 0,
pages: 0,
isLoading: false,
}),
reducers: { reducers: {
// Can pass adapter functions directly as case reducers. Because we're passing this // the adapter provides some helper reducers; see the docs for all of them
// as a value, `createSlice` will auto-generate the action type / creator // can use them as helper functions within a reducer, or use the function itself as a reducer
// here we just use the function itself as the reducer. we'll call this on `invocation_complete`
// to add a single result
resultAdded: resultsAdapter.addOne, resultAdded: resultsAdapter.addOne,
resultsReceived: resultsAdapter.setAll, },
extraReducers: (builder) => {
// here we can respond to a fulfilled call of the `getNextResultsPage` thunk
// because we pass in the fulfilled thunk action creator, everything is typed
builder.addCase(getNextResultsPage.pending, (state, action) => {
state.isLoading = true;
});
builder.addCase(getNextResultsPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
// build flattened array of results ojects, use lodash `map()` to make results object an array
const allResults = items.flatMap((session) => map(session.results));
// filter out non-image-outputs (eg latents, prompts, etc)
const imageOutputResults = allResults.filter(isImageOutput);
// map results to ResultImage objects
const resultImages = imageOutputResults.map((result) =>
prepareResultImage(result.image)
);
// use the adapter reducer to add all the results to resultsSlice state
resultsAdapter.addMany(state, resultImages);
state.page = page;
state.pages = pages;
state.isLoading = false;
});
}, },
}); });
// Can create a set of memoized selectors based on the location of this entity state // Create a set of memoized selectors based on the location of this entity state
// to be used as selectors in a `useAppSelector()` call
export const { export const {
selectAll: selectResultsAll, selectAll: selectResultsAll,
selectById: selectResultsById, selectById: selectResultsById,
@ -30,6 +84,6 @@ export const {
selectTotal: selectResultsTotal, selectTotal: selectResultsTotal,
} = resultsAdapter.getSelectors<RootState>((state) => state.results); } = resultsAdapter.getSelectors<RootState>((state) => state.results);
export const { resultAdded, resultsReceived } = resultsSlice.actions; export const { resultAdded } = resultsSlice.actions;
export default resultsSlice.reducer; export default resultsSlice.reducer;

View File

@ -1,91 +1,31 @@
import { createAppAsyncThunk } from 'app/storeUtils'; import { createAppAsyncThunk } from 'app/storeUtils';
import { map } from 'lodash';
import { SessionsService } from 'services/api'; import { SessionsService } from 'services/api';
import { isImageOutput } from 'services/types/guards';
import { buildImageUrls } from 'services/util/buildImageUrls';
import { extractTimestampFromResultImageName } from 'services/util/extractTimestampFromResultImageName';
import { resultsReceived } from 'features/gallery/store/resultsSlice';
type GetGalleryImagesArg = {
count: number;
};
/** /**
* Get the last 20 sessions' worth of images. * Get the last 10 sessions' worth of images.
* *
* This should be at most 20 images so long as we continue to make a new session for every * This should be at most 10 images so long as we continue to make a new session for every
* generation. * generation.
* *
* If a session was created but no image generated, this will be < 20 images. * If a session was created but no image generated, this will be < 10 images.
* *
* When we allow more images per sesssion, this is kinda no longer a viable way to grab results, * When we allow more images per sesssion, this is kinda no longer a viable way to grab results,
* because a session could have many, many images. In that situation, barring a change to the api, * because a session could have many, many images. In that situation, barring a change to the api,
* we have to keep track of images we've grabbed and the session they came from, so that when we * we have to keep track of images we've grabbed and the session they came from, so that when we
* want to load more, we can "resume" fetching images from that session. * want to load more, we can "resume" fetching images from that session.
*
* The API should change.
*/ */
export const getGalleryImages = createAppAsyncThunk( export const getNextResultsPage = createAppAsyncThunk(
'api/getGalleryImages', 'results/getMoreResultsImages',
async (arg: GetGalleryImagesArg, { dispatch }) => { async (_arg, { getState }) => {
const { page } = getState().results;
const response = await SessionsService.listSessions({ const response = await SessionsService.listSessions({
page: 0, page: page + 1,
perPage: 20, perPage: 10,
}); });
// build flattened array of results ojects, use lodash `map()` to make results object an array return response;
const allResults = response.items.flatMap((session) =>
map(session.results)
);
// filter out non-image-outputs (eg latents, prompts, etc)
const imageOutputResults = allResults.filter(isImageOutput);
// build ResultImage objects
const resultImages = imageOutputResults.map((result) => {
const name = result.image.image_name;
const { imageUrl, thumbnailUrl } = buildImageUrls('results', name);
const timestamp = extractTimestampFromResultImageName(name);
return {
name,
url: imageUrl,
thumbnail: thumbnailUrl,
timestamp,
height: 512,
width: 512,
};
});
// update the results slice
dispatch(resultsReceived(resultImages));
// response.items.forEach((session) => {
// forEach(session.results, (result) => {
// if (isImageOutput(result)) {
// const { imageUrl, thumbnailUrl } = buildImageUrls(
// result.image.image_type!, // fix the generated types to avoid non-null assertion
// result.image.image_name! // fix the generated types to avoid non-null assertion
// );
// dispatch
// dispatch(
// addImage({
// category: 'result',
// image: {
// uuid: uuidv4(),
// url: imageUrl,
// thumbnail: ,
// width: 512,
// height: 512,
// category: 'result',
// name: result.image.image_name,
// mtime: new Date().getTime(),
// },
// })
// );
// }
// });
// });
} }
); );

View File

@ -1,17 +0,0 @@
import { ImageType } from 'services/api';
export const buildImageUrls = (
imageType: ImageType,
imageName: string
): { imageUrl: string; thumbnailUrl: string } => {
const imageUrl = `api/v1/images/${imageType}/${imageName}`;
const thumbnailUrl = `api/v1/images/${imageType}/thumbnails/${
imageName.split('.')[0]
}.webp`;
return {
imageUrl,
thumbnailUrl,
};
};

View File

@ -1,9 +0,0 @@
export const extractTimestampFromResultImageName = (imageName: string) => {
const timestamp = imageName.split('_')?.pop()?.split('.')[0];
if (timestamp === undefined) {
return 0;
}
return Number(timestamp);
};

View File

@ -0,0 +1,45 @@
import { ResultImage } from 'app/invokeai';
import { ImageField, ImageType } from 'services/api';
export const buildImageUrls = (
imageType: ImageType,
imageName: string
): { imageUrl: string; thumbnailUrl: string } => {
const imageUrl = `api/v1/images/${imageType}/${imageName}`;
const thumbnailUrl = `api/v1/images/${imageType}/thumbnails/${
imageName.split('.')[0]
}.webp`;
return {
imageUrl,
thumbnailUrl,
};
};
export const extractTimestampFromResultImageName = (imageName: string) => {
const timestamp = imageName.split('_')?.pop()?.split('.')[0];
if (timestamp === undefined) {
return 0;
}
return Number(timestamp);
};
export const prepareResultImage = (image: ImageField): ResultImage => {
const name = image.image_name;
const { imageUrl, thumbnailUrl } = buildImageUrls('results', name);
const timestamp = extractTimestampFromResultImageName(name);
return {
name,
url: imageUrl,
thumbnail: thumbnailUrl,
timestamp,
height: 512,
width: 512,
};
};