perf(mm): add manual query cache updates for the update model route

This greatly reduces the number of network requests when editing models.
This commit is contained in:
psychedelicious 2024-03-07 18:18:32 +11:00
parent ff66779aa3
commit 0aa2070ce0

View File

@ -2,6 +2,13 @@ import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import queryString from 'query-string';
import {
ALL_BASE_MODELS,
NON_REFINER_BASE_MODELS,
NON_SDXL_MAIN_MODELS,
REFINER_BASE_MODELS,
SDXL_MAIN_MODELS,
} from 'services/api/constants';
import type { operations, paths } from 'services/api/schema';
import type {
AnyModelConfig,
@ -150,7 +157,11 @@ export const modelsApi = api.injectEndpoints({
body: body,
};
},
invalidatesTags: ['Model'],
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertSingleModelConfig(data, dispatch);
});
},
}),
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
query: ({ key, image }) => {
@ -379,9 +390,121 @@ const upsertModelConfigs = (
// eslint-disable-next-line @typescript-eslint/no-explicit-any
dispatch: ThunkDispatch<any, any, UnknownAction>
) => {
/**
* Once a list of models of a specific type is received, fetching any of those models individually is a waste of a
* network request. This function takes the received list of models and upserts them into the individual query caches
* for each model type.
*/
// Iterate over all the models and upsert them into the individual query caches for each model type.
anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => {
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
const { base, name, type } = modelConfig;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
});
};
const upsertSingleModelConfig = (
modelConfig: AnyModelConfig,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
dispatch: ThunkDispatch<any, any, UnknownAction>
) => {
/**
* When a model is updated, the individual query caches for each model type need to be updated, as well as the list
* query caches of models of that type.
*/
// Update the individual model query caches.
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
const { base, name, type } = modelConfig;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
// Update the list query caches for each model type.
if (modelConfig.type === 'main') {
[ALL_BASE_MODELS, NON_REFINER_BASE_MODELS, SDXL_MAIN_MODELS, NON_SDXL_MAIN_MODELS, REFINER_BASE_MODELS].forEach(
(queryArg) => {
dispatch(
modelsApi.util.updateQueryData('getMainModels', queryArg, (draft) => {
mainModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
}
);
return;
}
if (modelConfig.type === 'controlnet') {
dispatch(
modelsApi.util.updateQueryData('getControlNetModels', undefined, (draft) => {
controlNetModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'embedding') {
dispatch(
modelsApi.util.updateQueryData('getTextualInversionModels', undefined, (draft) => {
textualInversionModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'ip_adapter') {
dispatch(
modelsApi.util.updateQueryData('getIPAdapterModels', undefined, (draft) => {
ipAdapterModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'lora') {
dispatch(
modelsApi.util.updateQueryData('getLoRAModels', undefined, (draft) => {
loraModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 't2i_adapter') {
dispatch(
modelsApi.util.updateQueryData('getT2IAdapterModels', undefined, (draft) => {
t2iAdapterModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'vae') {
dispatch(
modelsApi.util.updateQueryData('getVaeModels', undefined, (draft) => {
vaeModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
};