refactor(ui): fix more types

This commit is contained in:
psychedelicious 2024-06-16 12:45:05 +10:00
parent d6bd1e4a49
commit 7ef4553fc9
11 changed files with 68 additions and 61 deletions

View File

@ -10,7 +10,9 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
useGlobalModifiersInit(); useGlobalModifiersInit();
useEffect(() => { useEffect(() => {
dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' })); dispatch(
modelChanged({ model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' } })
);
}, []); }, []);
return props.children; return props.children;

View File

@ -1,5 +1,5 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setInfillMethod } from 'features/canvas/store/canvasSlice'; import { setInfillMethod } from 'features/controlLayers/store/canvasV2Slice';
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice'; import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo'; import { appInfoApi } from 'services/api/endpoints/appInfo';
@ -8,7 +8,7 @@ export const addAppConfigReceivedListener = (startAppListening: AppStartListenin
matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled, matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
const { infill_methods = [], nsfw_methods = [], watermarking_methods = [] } = action.payload; const { infill_methods = [], nsfw_methods = [], watermarking_methods = [] } = action.payload;
const infillMethod = getState().generation.infillMethod; const infillMethod = getState().canvasV2.compositing.infillMethod;
if (!infill_methods.includes(infillMethod)) { if (!infill_methods.includes(infillMethod)) {
// if there is no infill method, set it to the first one // if there is no infill method, set it to the first one

View File

@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasImageToControlAdapter } from 'features/canvas/store/actions'; import { canvasImageToControlAdapter } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import { caImageChanged } from 'features/controlLayers/store/canvasV2Slice';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -47,14 +47,7 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
}) })
).unwrap(); ).unwrap();
const { image_name } = imageDTO; dispatch(caImageChanged({ id, imageDTO }));
dispatch(
controlAdapterImageChanged({
id,
controlImage: image_name,
})
);
}, },
}); });
}; };

View File

@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMaskToControlAdapter } from 'features/canvas/store/actions'; import { canvasMaskToControlAdapter } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData'; import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import { caImageChanged } from 'features/controlLayers/store/canvasV2Slice';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -57,14 +57,7 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
}) })
).unwrap(); ).unwrap();
const { image_name } = imageDTO; dispatch(caImageChanged({ id, imageDTO }));
dispatch(
controlAdapterImageChanged({
id,
controlImage: image_name,
})
);
}, },
}); });
}; };

View File

@ -1,12 +1,8 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { import { caIsEnabledToggled, modelChanged, vaeSelected } from 'features/controlLayers/store/canvasV2Slice';
controlAdapterIsEnabledChanged,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice'; import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions'; import { modelSelected } from 'features/parameters/store/actions';
import { modelChanged, vaeSelected } from 'features/canvas/store/canvasSlice';
import { zParameterModel } from 'features/parameters/types/parameterSchemas'; import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
@ -51,10 +47,12 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
} }
// handle incompatible controlnets // handle incompatible controlnets
selectControlAdapterAll(state.controlAdapters).forEach((ca) => { state.canvasV2.controlAdapters.forEach((ca) => {
if (ca.model?.base !== newBaseModel) { if (ca.model?.base !== newBaseModel) {
dispatch(controlAdapterIsEnabledChanged({ id: ca.id, isEnabled: false }));
modelsCleared += 1; modelsCleared += 1;
if (ca.isEnabled) {
dispatch(caIsEnabledToggled({ id: ca.id }));
}
} }
}); });

View File

@ -5,8 +5,10 @@ import type { JSONObject } from 'common/types';
import { import {
caModelChanged, caModelChanged,
heightChanged, heightChanged,
ipaModelChanged,
modelChanged, modelChanged,
refinerModelChanged, refinerModelChanged,
rgIPAdapterModelChanged,
vaeSelected, vaeSelected,
widthChanged, widthChanged,
} from 'features/controlLayers/store/canvasV2Slice'; } from 'features/controlLayers/store/canvasV2Slice';
@ -20,6 +22,9 @@ import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models'; import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types'; import type { AnyModelConfig } from 'services/api/types';
import { import {
isControlNetOrT2IAdapterModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig, isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig, isRefinerMainModelModelConfig,
isSpandrelImageToImageModelConfig, isSpandrelImageToImageModelConfig,
@ -44,6 +49,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
handleLoRAModels(models, state, dispatch, log); handleLoRAModels(models, state, dispatch, log);
handleControlAdapterModels(models, state, dispatch, log); handleControlAdapterModels(models, state, dispatch, log);
handleSpandrelImageToImageModels(models, state, dispatch, log); handleSpandrelImageToImageModels(models, state, dispatch, log);
handleIPAdapterModels(models, state, dispatch, log);
}, },
}); });
}; };
@ -75,7 +81,7 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
if (defaultModelInList) { if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList); const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) { if (result.success) {
dispatch(modelChanged({ model: defaultModelInList, previousModel: currentModel ?? undefined })); dispatch(modelChanged({ model: defaultModelInList, previousModel: currentModel }));
const optimalDimension = getOptimalDimension(defaultModelInList); const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.canvasV2.document.width, state.canvasV2.document.height, optimalDimension)) { if (getIsSizeOptimal(state.canvasV2.document.width, state.canvasV2.document.height, optimalDimension)) {
@ -99,7 +105,7 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
return; return;
} }
dispatch(modelChanged({ model: result.data, previousModel: currentModel ?? undefined })); dispatch(modelChanged({ model: result.data, previousModel: currentModel }));
}; };
const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => { const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
@ -156,30 +162,49 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => { const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
const loras = state.lora.loras; const loras = state.lora.loras;
const loraModels = models.filter(isLoRAModelConfig);
forEach(loras, (lora, id) => { forEach(loras, (lora, id) => {
const isLoRAAvailable = models.some((m) => m.key === lora.model.key); const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key);
if (isLoRAAvailable) { if (isLoRAAvailable) {
return; return;
} }
dispatch(loraRemoved(id)); dispatch(loraRemoved(id));
}); });
}; };
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => { const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
const caModels = models.filter(isControlNetOrT2IAdapterModelConfig);
state.canvasV2.controlAdapters.forEach((ca) => { state.canvasV2.controlAdapters.forEach((ca) => {
const isModelAvailable = models.some((m) => m.key === ca.model?.key); const isModelAvailable = caModels.some((m) => m.key === ca.model?.key);
if (isModelAvailable) { if (isModelAvailable) {
return; return;
} }
dispatch(caModelChanged({ id: ca.id, modelConfig: null })); dispatch(caModelChanged({ id: ca.id, modelConfig: null }));
}); });
}; };
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
const ipaModels = models.filter(isIPAdapterModelConfig);
state.canvasV2.controlAdapters.forEach(({ id, model }) => {
const isModelAvailable = ipaModels.some((m) => m.key === model?.key);
if (isModelAvailable) {
return;
}
dispatch(ipaModelChanged({ id, modelConfig: null }));
});
state.canvasV2.regions.forEach(({ id, ipAdapters }) => {
ipAdapters.forEach(({ id: ipAdapterId, model }) => {
const isModelAvailable = ipaModels.some((m) => m.key === model?.key);
if (isModelAvailable) {
return;
}
dispatch(rgIPAdapterModelChanged({ id, ipAdapterId, modelConfig: null }));
});
});
};
const handleSpandrelImageToImageModels: ModelHandler = (models, state, dispatch, _log) => { const handleSpandrelImageToImageModels: ModelHandler = (models, state, dispatch, _log) => {
const { upscaleModel: currentUpscaleModel, postProcessingModel: currentPostProcessingModel } = state.upscale; const { upscaleModel: currentUpscaleModel, postProcessingModel: currentPostProcessingModel } = state.upscale;
const upscaleModels = models.filter(isSpandrelImageToImageModelConfig); const upscaleModels = models.filter(isSpandrelImageToImageModelConfig);

View File

@ -1,14 +1,15 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { heightChanged, widthChanged } from 'features/controlLayers/store/canvasV2Slice';
import { setDefaultSettings } from 'features/parameters/store/actions';
import { import {
heightChanged,
setCfgRescaleMultiplier, setCfgRescaleMultiplier,
setCfgScale, setCfgScale,
setScheduler, setScheduler,
setSteps, setSteps,
vaePrecisionChanged, vaePrecisionChanged,
vaeSelected, vaeSelected,
} from 'features/canvas/store/canvasSlice'; widthChanged,
} from 'features/controlLayers/store/canvasV2Slice';
import { setDefaultSettings } from 'features/parameters/store/actions';
import { import {
isParameterCFGRescaleMultiplier, isParameterCFGRescaleMultiplier,
isParameterCFGScale, isParameterCFGScale,

View File

@ -34,7 +34,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
// This complete event has an associated image output // This complete event has an associated image output
if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) { if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) {
const { image_name } = data.result.image; const { image_name } = data.result.image;
const { canvas, gallery } = getState(); const { canvasV2, gallery } = getState();
// This populates the `getImageDTO` cache // This populates the `getImageDTO` cache
const imageDTORequest = dispatch( const imageDTORequest = dispatch(
@ -47,7 +47,9 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
imageDTORequest.unsubscribe(); imageDTORequest.unsubscribe();
// Add canvas images to the staging area // Add canvas images to the staging area
if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) { // TODO(psyche): canvas batchid processing, [] -> canvas.batchIds
// if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) {
if ([].includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) {
dispatch(addImageToStagingArea(imageDTO)); dispatch(addImageToStagingArea(imageDTO));
} }

View File

@ -47,7 +47,10 @@ export const paramsReducers = {
setShouldRandomizeSeed: (state, action: PayloadAction<boolean>) => { setShouldRandomizeSeed: (state, action: PayloadAction<boolean>) => {
state.params.shouldRandomizeSeed = action.payload; state.params.shouldRandomizeSeed = action.payload;
}, },
modelChanged: (state, action: PayloadAction<{ model: ParameterModel | null; previousModel?: ParameterModel }>) => { modelChanged: (
state,
action: PayloadAction<{ model: ParameterModel | null; previousModel?: ParameterModel | null }>
) => {
const { model, previousModel } = action.payload; const { model, previousModel } = action.payload;
state.params.model = model; state.params.model = model;

View File

@ -1,11 +1,6 @@
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { iiLayerAdded } from 'features/controlLayers/store/canvasV2Slice';
import { selectOptimalDimension } from 'features/controlLayers/store/selectors';
import { parseAndRecallAllMetadata } from 'features/metadata/util/handlers'; import { parseAndRecallAllMetadata } from 'features/metadata/util/handlers';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { t } from 'i18next'; import { t } from 'i18next';
import { useCallback, useEffect } from 'react'; import { useCallback, useEffect } from 'react';
import { useGetImageDTOQuery, useGetImageMetadataQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery, useGetImageMetadataQuery } from 'services/api/endpoints/images';
@ -14,31 +9,30 @@ export const usePreselectedImage = (selectedImage?: {
imageName: string; imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters'; action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
}) => { }) => {
const dispatch = useAppDispatch();
const optimalDimension = useAppSelector(selectOptimalDimension);
const { currentData: selectedImageDto } = useGetImageDTOQuery(selectedImage?.imageName ?? skipToken); const { currentData: selectedImageDto } = useGetImageDTOQuery(selectedImage?.imageName ?? skipToken);
const { currentData: selectedImageMetadata } = useGetImageMetadataQuery(selectedImage?.imageName ?? skipToken); const { currentData: selectedImageMetadata } = useGetImageMetadataQuery(selectedImage?.imageName ?? skipToken);
const handleSendToCanvas = useCallback(() => { const handleSendToCanvas = useCallback(() => {
if (selectedImageDto) { if (selectedImageDto) {
dispatch(setInitialCanvasImage(selectedImageDto, optimalDimension)); // TODO(psyche): handle send to canvas
dispatch(setActiveTab('canvas')); // dispatch(setInitialCanvasImage(selectedImageDto, optimalDimension));
// dispatch(setActiveTab('canvas'));
toast({ toast({
id: 'SENT_TO_CANVAS', id: 'SENT_TO_CANVAS',
title: t('toast.sentToUnifiedCanvas'), title: t('toast.sentToUnifiedCanvas'),
status: 'info', status: 'info',
}); });
} }
}, [selectedImageDto, dispatch, optimalDimension]); }, [selectedImageDto]);
const handleSendToImg2Img = useCallback(() => { const handleSendToImg2Img = useCallback(() => {
if (selectedImageDto) { if (selectedImageDto) {
dispatch(iiLayerAdded(selectedImageDto)); // TODO(psyche): handle send to img2img
dispatch(setActiveTab('generation')); // dispatch(iiLayerAdded(selectedImageDto));
// dispatch(setActiveTab('generation'));
} }
}, [dispatch, selectedImageDto]); }, [selectedImageDto]);
const handleUseAllMetadata = useCallback(() => { const handleUseAllMetadata = useCallback(() => {
if (selectedImageMetadata) { if (selectedImageMetadata) {

View File

@ -1,13 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { selectGenerationSlice } from 'features/canvas/store/canvasSlice';
import { useGetModelConfigQuery } from 'services/api/endpoints/models'; import { useGetModelConfigQuery } from 'services/api/endpoints/models';
const selectModelKey = createSelector(selectGenerationSlice, (generation) => generation.model?.key);
export const useSelectedModelConfig = () => { export const useSelectedModelConfig = () => {
const key = useAppSelector(selectModelKey); const key = useAppSelector((s) => s.canvasV2.params.model?.key);
const { currentData: modelConfig } = useGetModelConfigQuery(key ?? skipToken); const { currentData: modelConfig } = useGetModelConfigQuery(key ?? skipToken);
return modelConfig; return modelConfig;