fix(ui): fix excessive re-renders

This commit is contained in:
psychedelicious 2023-10-06 16:37:47 +11:00
parent 9508e0c9db
commit 6b8ce34eb3
3 changed files with 70 additions and 81 deletions

View File

@ -4,7 +4,6 @@ import { isEqual } from 'lodash-es';
import { import {
ButtonGroup, ButtonGroup,
Flex, Flex,
FlexProps,
Menu, Menu,
MenuButton, MenuButton,
MenuList, MenuList,
@ -82,9 +81,7 @@ const currentImageButtonsSelector = createSelector(
} }
); );
type CurrentImageButtonsProps = FlexProps; const CurrentImageButtons = () => {
const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { const {
isConnected, isConnected,
@ -248,10 +245,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
alignItems: 'center', alignItems: 'center',
gap: 2, gap: 2,
}} }}
{...props}
> >
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<Menu> <Menu isLazy>
<MenuButton <MenuButton
as={IAIIconButton} as={IAIIconButton}
aria-label={t('parameters.imageActions')} aria-label={t('parameters.imageActions')}

View File

@ -1,7 +1,7 @@
import { ButtonGroup, Divider, Flex } from '@chakra-ui/react'; import { ButtonGroup, Divider, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
@ -17,37 +17,36 @@ import {
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { Fragment, memo } from 'react'; import { Fragment, memo } from 'react';
import { FaPlus } from 'react-icons/fa'; import { FaPlus } from 'react-icons/fa';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
({ controlAdapters }) => { ({ controlAdapters }) => {
const activeLabel: string[] = []; const activeLabel: string[] = [];
const validIPAdapters = selectAllIPAdapters(controlAdapters); const ipAdapters = selectAllIPAdapters(controlAdapters);
const validIPAdapterCount = validIPAdapters.length; const ipAdapterCount = ipAdapters.length;
if (validIPAdapterCount > 0) { if (ipAdapterCount > 0) {
activeLabel.push(`${validIPAdapterCount} IP`); activeLabel.push(`${ipAdapterCount} IP`);
} }
const validControlNets = selectAllControlNets(controlAdapters); const controlNets = selectAllControlNets(controlAdapters);
const validControlNetCount = validControlNets.length; const controlNetCount = controlNets.length;
if (validControlNetCount > 0) { if (controlNetCount > 0) {
activeLabel.push(`${validControlNetCount} ControlNet`); activeLabel.push(`${controlNetCount} ControlNet`);
} }
const validT2IAdapters = selectAllT2IAdapters(controlAdapters); const t2iAdapters = selectAllT2IAdapters(controlAdapters);
const validT2IAdapterCount = validT2IAdapters.length; const t2iAdapterCount = t2iAdapters.length;
if (validT2IAdapterCount > 0) { if (t2iAdapterCount > 0) {
activeLabel.push(`${validT2IAdapterCount} T2I`); activeLabel.push(`${t2iAdapterCount} T2I`);
} }
const controlAdapterIds = [ipAdapters, controlNets, t2iAdapters]
.flat()
.map((ca) => ca.id);
return { return {
controlAdapters: [ controlAdapterIds,
...validIPAdapters,
...validControlNets,
...validT2IAdapters,
],
activeLabel: activeLabel.join(', '), activeLabel: activeLabel.join(', '),
}; };
}, },
@ -55,14 +54,16 @@ const selector = createSelector(
); );
const ParamControlNetCollapse = () => { const ParamControlNetCollapse = () => {
const { controlAdapters, activeLabel } = useAppSelector(selector); const { controlAdapterIds, activeLabel } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled; const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const dispatch = useAppDispatch();
const { data: controlnetModels } = useGetControlNetModelsQuery();
const { addControlNet } = useAddControlNet(); const { addControlNet } = useAddControlNet();
const { addIPAdapter } = useAddIPAdapter(); const { addIPAdapter } = useAddIPAdapter();
const { addT2IAdapter } = useAddT2IAdapter(); const { addT2IAdapter } = useAddT2IAdapter();
if (isControlNetDisabled) {
return null;
}
return ( return (
<IAICollapse label="Control Adapters" activeLabel={activeLabel}> <IAICollapse label="Control Adapters" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 2 }}> <Flex sx={{ flexDir: 'column', gap: 2 }}>
@ -89,10 +90,10 @@ const ParamControlNetCollapse = () => {
T2I Adapter T2I Adapter
</IAIButton> </IAIButton>
</ButtonGroup> </ButtonGroup>
{controlAdapters.map((ca, i) => ( {controlAdapterIds.map((id, i) => (
<Fragment key={ca.id}> <Fragment key={id}>
{i > 0 && <Divider />} {i > 0 && <Divider />}
<ControlNet id={ca.id} /> <ControlNet id={id} />
</Fragment> </Fragment>
))} ))}
</Flex> </Flex>

View File

@ -2,11 +2,16 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from 'features/controlNet/store/constants';
import { import {
CoreMetadata,
LoRAMetadataItem,
ControlNetMetadataItem, ControlNetMetadataItem,
CoreMetadata,
IPAdapterMetadataItem, IPAdapterMetadataItem,
LoRAMetadataItem,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { import {
refinerModelChanged, refinerModelChanged,
@ -22,12 +27,13 @@ import {
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
import { import {
controlNetModelsAdapter, controlNetModelsAdapter,
ipAdapterModelsAdapter, ipAdapterModelsAdapter,
useGetIPAdapterModelsQuery,
loraModelsAdapter, loraModelsAdapter,
useGetControlNetModelsQuery, useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery, useGetLoRAModelsQuery,
} from '../../../services/api/endpoints/models'; } from '../../../services/api/endpoints/models';
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice'; import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
@ -45,10 +51,10 @@ import {
} from '../store/generationSlice'; } from '../store/generationSlice';
import { import {
isValidCfgScale, isValidCfgScale,
isValidHeight,
isValidLoRAModel,
isValidControlNetModel, isValidControlNetModel,
isValidHeight,
isValidIPAdapterModel, isValidIPAdapterModel,
isValidLoRAModel,
isValidMainModel, isValidMainModel,
isValidNegativePrompt, isValidNegativePrompt,
isValidPositivePrompt, isValidPositivePrompt,
@ -64,22 +70,18 @@ import {
isValidStrength, isValidStrength,
isValidWidth, isValidWidth,
} from '../types/parameterSchemas'; } from '../types/parameterSchemas';
import { v4 as uuidv4 } from 'uuid';
import {
CONTROLNET_PROCESSORS,
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
} from 'features/controlNet/store/constants';
const selector = createSelector(stateSelector, ({ generation }) => { const selector = createSelector(
const { model } = generation; stateSelector,
return { model }; ({ generation }) => generation.model,
}); defaultSelectorOptions
);
export const useRecallParameters = () => { export const useRecallParameters = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const toaster = useAppToaster(); const toaster = useAppToaster();
const { t } = useTranslation(); const { t } = useTranslation();
const { model } = useAppSelector(selector); const model = useAppSelector(selector);
const parameterSetToast = useCallback(() => { const parameterSetToast = useCallback(() => {
toaster({ toaster({
@ -349,13 +351,7 @@ export const useRecallParameters = () => {
* Recall LoRA with toast * Recall LoRA with toast
*/ */
const { loras } = useGetLoRAModelsQuery(undefined, { const { data: loras } = useGetLoRAModelsQuery(undefined);
selectFromResult: (result) => ({
loras: result.data
? loraModelsAdapter.getSelectors().selectAll(result.data)
: [],
}),
});
const prepareLoRAMetadataItem = useCallback( const prepareLoRAMetadataItem = useCallback(
(loraMetadataItem: LoRAMetadataItem) => { (loraMetadataItem: LoRAMetadataItem) => {
@ -365,9 +361,11 @@ export const useRecallParameters = () => {
const { base_model, model_name } = loraMetadataItem.lora; const { base_model, model_name } = loraMetadataItem.lora;
const matchingLoRA = loras.find( const matchingLoRA = loras
(l) => l.base_model === base_model && l.model_name === model_name ? loraModelsAdapter
); .getSelectors()
.selectById(loras, `${base_model}/lora/${model_name}`)
: undefined;
if (!matchingLoRA) { if (!matchingLoRA) {
return { lora: null, error: 'LoRA model is not installed' }; return { lora: null, error: 'LoRA model is not installed' };
@ -410,13 +408,7 @@ export const useRecallParameters = () => {
* Recall ControlNet with toast * Recall ControlNet with toast
*/ */
const { controlnets } = useGetControlNetModelsQuery(undefined, { const { data: controlNets } = useGetControlNetModelsQuery(undefined);
selectFromResult: (result) => ({
controlnets: result.data
? controlNetModelsAdapter.getSelectors().selectAll(result.data)
: [],
}),
});
const prepareControlNetMetadataItem = useCallback( const prepareControlNetMetadataItem = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem) => { (controlnetMetadataItem: ControlNetMetadataItem) => {
@ -434,11 +426,14 @@ export const useRecallParameters = () => {
resize_mode, resize_mode,
} = controlnetMetadataItem; } = controlnetMetadataItem;
const matchingControlNetModel = controlnets.find( const matchingControlNetModel = controlNets
(c) => ? controlNetModelsAdapter
c.base_model === control_model.base_model && .getSelectors()
c.model_name === control_model.model_name .selectById(
); controlNets,
`${control_model.base_model}/controlnet/${control_model.model_name}`
)
: undefined;
if (!matchingControlNetModel) { if (!matchingControlNetModel) {
return { controlnet: null, error: 'ControlNet model is not installed' }; return { controlnet: null, error: 'ControlNet model is not installed' };
@ -491,7 +486,7 @@ export const useRecallParameters = () => {
return { controlnet, error: null }; return { controlnet, error: null };
}, },
[controlnets, model?.base_model] [controlNets, model?.base_model]
); );
const recallControlNet = useCallback( const recallControlNet = useCallback(
@ -523,13 +518,7 @@ export const useRecallParameters = () => {
* Recall IP Adapter with toast * Recall IP Adapter with toast
*/ */
const { ipAdapters } = useGetIPAdapterModelsQuery(undefined, { const { data: ipAdapters } = useGetIPAdapterModelsQuery(undefined);
selectFromResult: (result) => ({
ipAdapters: result.data
? ipAdapterModelsAdapter.getSelectors().selectAll(result.data)
: [],
}),
});
const prepareIPAdapterMetadataItem = useCallback( const prepareIPAdapterMetadataItem = useCallback(
(ipAdapterMetadataItem: IPAdapterMetadataItem) => { (ipAdapterMetadataItem: IPAdapterMetadataItem) => {
@ -545,11 +534,14 @@ export const useRecallParameters = () => {
end_step_percent, end_step_percent,
} = ipAdapterMetadataItem; } = ipAdapterMetadataItem;
const matchingIPAdapterModel = ipAdapters.find( const matchingIPAdapterModel = ipAdapters
(c) => ? ipAdapterModelsAdapter
c.base_model === ip_adapter_model?.base_model && .getSelectors()
c.model_name === ip_adapter_model?.model_name .selectById(
); ipAdapters,
`${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}`
)
: undefined;
if (!matchingIPAdapterModel) { if (!matchingIPAdapterModel) {
return { ipAdapter: null, error: 'IP Adapter model is not installed' }; return { ipAdapter: null, error: 'IP Adapter model is not installed' };