feat(ui): use query to populate infill methods dropdown

- available infill methods is server state - remove it from client state, use the query to populate the dropdown
- add listener to ensure the selected infill method is an available one
This commit is contained in:
psychedelicious 2023-07-13 15:22:25 +10:00
parent 4d25d702a1
commit 978016ea51
5 changed files with 41 additions and 44 deletions

View File

@ -8,6 +8,7 @@ import {
import type { AppDispatch, RootState } from '../../store'; import type { AppDispatch, RootState } from '../../store';
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener'; import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
import { addAppConfigReceivedListener } from './listeners/appConfigReceived';
import { addAppStartedListener } from './listeners/appStarted'; import { addAppStartedListener } from './listeners/appStarted';
import { addBoardIdSelectedListener } from './listeners/boardIdSelected'; import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesDeleted'; import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesDeleted';
@ -226,3 +227,4 @@ addModelSelectedListener();
// app startup // app startup
addAppStartedListener(); addAppStartedListener();
addModelsLoadedListener(); addModelsLoadedListener();
addAppConfigReceivedListener();

View File

@ -0,0 +1,17 @@
import { setInfillMethod } from 'features/parameters/store/generationSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import { startAppListening } from '..';
export const addAppConfigReceivedListener = () => {
startAppListening({
matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
const { infill_methods } = action.payload;
const infillMethod = getState().generation.infillMethod;
if (!infill_methods.includes(infillMethod)) {
dispatch(setInfillMethod(infill_methods[0]));
}
},
});
};

View File

@ -1,25 +1,21 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setInfillMethod } from 'features/parameters/store/generationSlice'; import { setInfillMethod } from 'features/parameters/store/generationSlice';
import { systemSelector } from 'features/system/store/systemSelectors';
import { memo, useCallback, useEffect, useMemo } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetAppConfigQuery } from '../../../../../../services/api/endpoints/appInfo'; import { useGetAppConfigQuery } from 'services/api/endpoints/appInfo';
import { setAvailableInfillMethods } from '../../../../../system/store/systemSlice';
const selector = createSelector( const selector = createSelector(
[generationSelector, systemSelector], [stateSelector],
(parameters, system) => { ({ generation }) => {
const { infillMethod } = parameters; const { infillMethod } = generation;
const { infillMethods } = system;
return { return {
infillMethod, infillMethod,
infillMethods,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -27,9 +23,11 @@ const selector = createSelector(
const ParamInfillMethod = () => { const ParamInfillMethod = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { infillMethod, infillMethods } = useAppSelector(selector); const { infillMethod } = useAppSelector(selector);
const { data: appConfigData } = useGetAppConfigQuery(); const { data: appConfigData, isLoading } = useGetAppConfigQuery();
const infill_methods = appConfigData?.infill_methods;
const { t } = useTranslation(); const { t } = useTranslation();
@ -40,24 +38,13 @@ const ParamInfillMethod = () => {
[dispatch] [dispatch]
); );
useEffect(() => {
if (!appConfigData) return;
if (!appConfigData.patchmatch_enabled) {
const filteredMethods = infillMethods.filter(
(method) => method !== 'patchmatch'
);
dispatch(setAvailableInfillMethods(filteredMethods));
dispatch(setInfillMethod(filteredMethods[0]));
} else {
dispatch(setInfillMethod('patchmatch'));
}
}, [appConfigData, infillMethods, dispatch]);
return ( return (
<IAIMantineSelect <IAIMantineSelect
disabled={infill_methods?.length === 0}
placeholder={isLoading ? 'Loading...' : undefined}
label={t('parameters.infillMethod')} label={t('parameters.infillMethod')}
value={infillMethod} value={infillMethod}
data={infillMethods} data={infill_methods ?? []}
onChange={handleChange} onChange={handleChange}
/> />
); );

View File

@ -4,8 +4,14 @@ import * as InvokeAI from 'app/types/invokeai';
import { InvokeLogLevel } from 'app/logging/useLogger'; import { InvokeLogLevel } from 'app/logging/useLogger';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { TFuncKey, t } from 'i18next'; import { TFuncKey, t } from 'i18next';
import { LogLevelName } from 'roarr'; import { LogLevelName } from 'roarr';
import { imageUploaded } from 'services/api/thunks/image';
import {
isAnySessionRejected,
sessionCanceled,
} from 'services/api/thunks/session';
import { import {
appSocketConnected, appSocketConnected,
appSocketDisconnected, appSocketDisconnected,
@ -18,19 +24,11 @@ import {
appSocketUnsubscribed, appSocketUnsubscribed,
} from 'services/events/actions'; } from 'services/events/actions';
import { ProgressImage } from 'services/events/types'; import { ProgressImage } from 'services/events/types';
import { imageUploaded } from 'services/api/thunks/image';
import {
isAnySessionRejected,
sessionCanceled,
} from 'services/api/thunks/session';
import { makeToast } from '../../../app/components/Toaster'; import { makeToast } from '../../../app/components/Toaster';
import { LANGUAGES } from '../components/LanguagePicker'; import { LANGUAGES } from '../components/LanguagePicker';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
export type CancelStrategy = 'immediate' | 'scheduled'; export type CancelStrategy = 'immediate' | 'scheduled';
export type InfillMethod = 'tile' | 'patchmatch';
export interface SystemState { export interface SystemState {
isGFPGANAvailable: boolean; isGFPGANAvailable: boolean;
isESRGANAvailable: boolean; isESRGANAvailable: boolean;
@ -87,10 +85,6 @@ export interface SystemState {
* When a session is canceled, its ID is stored here until a new session is created. * When a session is canceled, its ID is stored here until a new session is created.
*/ */
canceledSession: string; canceledSession: string;
/**
* TODO: get this from backend
*/
infillMethods: InfillMethod[];
isPersisted: boolean; isPersisted: boolean;
shouldAntialiasProgressImage: boolean; shouldAntialiasProgressImage: boolean;
language: keyof typeof LANGUAGES; language: keyof typeof LANGUAGES;
@ -128,7 +122,6 @@ export const initialSystemState: SystemState = {
shouldLogToConsole: true, shouldLogToConsole: true,
statusTranslationKey: 'common.statusDisconnected', statusTranslationKey: 'common.statusDisconnected',
canceledSession: '', canceledSession: '',
infillMethods: ['tile', 'patchmatch'],
isPersisted: false, isPersisted: false,
language: 'en', language: 'en',
isUploading: false, isUploading: false,
@ -219,9 +212,6 @@ export const systemSlice = createSlice({
progressImageSet(state, action: PayloadAction<ProgressImage | null>) { progressImageSet(state, action: PayloadAction<ProgressImage | null>) {
state.progressImage = action.payload; state.progressImage = action.payload;
}, },
setAvailableInfillMethods(state, action: PayloadAction<InfillMethod[]>) {
state.infillMethods = action.payload;
},
}, },
extraReducers(builder) { extraReducers(builder) {
/** /**
@ -454,7 +444,6 @@ export const {
shouldAntialiasProgressImageChanged, shouldAntialiasProgressImageChanged,
languageChanged, languageChanged,
progressImageSet, progressImageSet,
setAvailableInfillMethods,
} = systemSlice.actions; } = systemSlice.actions;
export default systemSlice.reducer; export default systemSlice.reducer;

View File

@ -1,5 +1,5 @@
import { api } from '..'; import { api } from '..';
import { AppVersion, AppConfig } from '../types'; import { AppConfig, AppVersion } from '../types';
export const appInfoApi = api.injectEndpoints({ export const appInfoApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
@ -8,12 +8,14 @@ export const appInfoApi = api.injectEndpoints({
url: `app/version`, url: `app/version`,
method: 'GET', method: 'GET',
}), }),
keepUnusedDataFor: 86400000, // 1 day
}), }),
getAppConfig: build.query<AppConfig, void>({ getAppConfig: build.query<AppConfig, void>({
query: () => ({ query: () => ({
url: `app/config`, url: `app/config`,
method: 'GET', method: 'GET',
}), }),
keepUnusedDataFor: 86400000, // 1 day
}), }),
}), }),
}); });