feat: Parametrize useGetMainModelsQuery

This commit is contained in:
blessedcoolant 2023-07-11 16:33:26 +12:00
parent 7ce43692c2
commit 4d9a342437
7 changed files with 38 additions and 25 deletions

View File

@ -3,13 +3,8 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
modelsApi,
useGetMainModelsQuery,
} from '../../services/api/endpoints/models';
import { modelsApi } from '../../services/api/endpoints/models';
const readinessSelector = createSelector(
[stateSelector, activeTabNameSelector],
@ -38,7 +33,10 @@ const readinessSelector = createSelector(
}
const { isSuccess: mainModelsSuccessfullyLoaded } =
modelsApi.endpoints.getMainModels.select()(state);
modelsApi.endpoints.getMainModels.select({
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
})(state);
if (!mainModelsSuccessfullyLoaded) {
isReady = false;
reasonsWhyNotReady.push('Models are not loaded');

View File

@ -22,7 +22,10 @@ const ModelInputFieldComponent = (
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: mainModels } = useGetMainModelsQuery();
const { data: mainModels } = useGetMainModelsQuery({
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
});
const data = useMemo(() => {
if (!mainModels) {

View File

@ -25,7 +25,10 @@ const ModelSelect = () => {
(state: RootState) => state.generation.model
);
const { data: mainModels, isLoading } = useGetMainModelsQuery();
const { data: mainModels, isLoading } = useGetMainModelsQuery({
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
});
const data = useMemo(() => {
if (!mainModels) {

View File

@ -16,7 +16,10 @@ export default function MergeModelsPanel() {
const dispatch = useAppDispatch();
const { data } = useGetMainModelsQuery();
const { data } = useGetMainModelsQuery({
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
});
const diffusersModels = pickBy(
data?.entities,

View File

@ -8,7 +8,10 @@ import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList';
export default function ModelManagerPanel() {
const { data: mainModels } = useGetMainModelsQuery();
const { data: mainModels } = useGetMainModelsQuery({
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
});
const openModel = useAppSelector(
(state: RootState) => state.system.openModel

View File

@ -36,7 +36,10 @@ function ModelFilterButton({
}
const ModelList = () => {
const { data: mainModels } = useGetMainModelsQuery();
const { data: mainModels } = useGetMainModelsQuery({
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
});
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);

View File

@ -2,9 +2,11 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
import { cloneDeep } from 'lodash-es';
import {
AnyModelConfig,
BaseModelType,
ControlNetModelConfig,
LoRAModelConfig,
MainModelConfig,
ModelType,
TextualInversionModelConfig,
VaeModelConfig,
} from 'services/api/types';
@ -68,21 +70,19 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
return entityArray;
};
type MainModelQueryArg = {
model_type: ModelType;
base_models: BaseModelType[];
};
export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
query: () => {
const baseModels = {
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
};
const baseModelsQueryStr = queryString.stringify(baseModels, {});
return {
url: `models/?${baseModelsQueryStr}`,
params: {
model_type: 'main',
},
};
},
getMainModels: build.query<
EntityState<MainModelConfigEntity>,
MainModelQueryArg
>({
query: (arg: MainModelQueryArg) =>
`models/?${queryString.stringify(arg)}`,
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'MainModel', type: LIST_TAG },