diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 5adc4f5e5e..f06c324bc6 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -65,18 +65,19 @@ import { addGeneratorProgressEventListener as addGeneratorProgressListener } fro import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete'; import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete'; import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError'; +import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError'; import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted'; import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad'; +import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError'; import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed'; import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed'; import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved'; +import { addTabChangedListener } from './listeners/tabChanged'; import { addUpscaleRequestedListener } from './listeners/upscaleRequested'; import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; -import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError'; -import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError'; export const listenerMiddleware = createListenerMiddleware(); @@ -201,3 +202,6 @@ addFirstListImagesListener(); // Ad-hoc upscale workflwo addUpscaleRequestedListener(); + +// Tab Change +addTabChangedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts new file mode 100644 index 0000000000..578241573c --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts @@ -0,0 +1,56 @@ +import { modelChanged } from 'features/parameters/store/generationSlice'; +import { setActiveTab } from 'features/ui/store/uiSlice'; +import { forEach } from 'lodash-es'; +import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; +import { + MainModelConfigEntity, + modelsApi, +} from 'services/api/endpoints/models'; +import { startAppListening } from '..'; + +export const addTabChangedListener = () => { + startAppListening({ + actionCreator: setActiveTab, + effect: (action, { getState, dispatch }) => { + const activeTabName = action.payload; + if (activeTabName === 'unifiedCanvas') { + // grab the models from RTK Query cache + const { data } = modelsApi.endpoints.getMainModels.select( + NON_REFINER_BASE_MODELS + )(getState()); + + if (!data) { + // no models yet, so we can't do anything + dispatch(modelChanged(null)); + return; + } + + // need to filter out all the invalid canvas models (currently, this is just sdxl) + const validCanvasModels: MainModelConfigEntity[] = []; + + forEach(data.entities, (entity) => { + if (!entity) { + return; + } + if (['sd-1', 'sd-2'].includes(entity.base_model)) { + validCanvasModels.push(entity); + } + }); + + // this could still be undefined even tho TS doesn't say so + const firstValidCanvasModel = validCanvasModels[0]; + + if (!firstValidCanvasModel) { + // uh oh, we have no models that are valid for canvas + dispatch(modelChanged(null)); + return; + } + + // only store the model name and base model in redux + const { base_model, model_name } = firstValidCanvasModel; + + dispatch(modelChanged({ base_model, model_name })); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx index 4f799dc330..d380da60bf 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx @@ -13,6 +13,7 @@ import { modelSelected } from 'features/parameters/store/actions'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton'; +import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { forEach } from 'lodash-es'; import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; @@ -35,6 +36,8 @@ const ParamMainModelSelect = () => { NON_REFINER_BASE_MODELS ); + const activeTabName = useAppSelector(activeTabNameSelector); + const data = useMemo(() => { if (!mainModels) { return []; @@ -43,7 +46,10 @@ const ParamMainModelSelect = () => { const data: SelectItem[] = []; forEach(mainModels.entities, (model, id) => { - if (!model) { + if ( + !model || + (activeTabName === 'unifiedCanvas' && model.base_model === 'sdxl') + ) { return; } @@ -55,7 +61,7 @@ const ParamMainModelSelect = () => { }); return data; - }, [mainModels]); + }, [mainModels, activeTabName]); // grab the full model entity from the RTK Query cache // TODO: maybe we should just store the full model entity in state?