feat: Parametrize useGetMainModelsQuery

This commit is contained in:
blessedcoolant 2023-07-11 16:33:26 +12:00
parent 7ce43692c2
commit 4d9a342437
7 changed files with 38 additions and 25 deletions

View File

@ -3,13 +3,8 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs'; import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { import { modelsApi } from '../../services/api/endpoints/models';
modelsApi,
useGetMainModelsQuery,
} from '../../services/api/endpoints/models';
const readinessSelector = createSelector( const readinessSelector = createSelector(
[stateSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
@ -38,7 +33,10 @@ const readinessSelector = createSelector(
} }
const { isSuccess: mainModelsSuccessfullyLoaded } = const { isSuccess: mainModelsSuccessfullyLoaded } =
modelsApi.endpoints.getMainModels.select()(state); modelsApi.endpoints.getMainModels.select({
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
})(state);
if (!mainModelsSuccessfullyLoaded) { if (!mainModelsSuccessfullyLoaded) {
isReady = false; isReady = false;
reasonsWhyNotReady.push('Models are not loaded'); reasonsWhyNotReady.push('Models are not loaded');

View File

@ -22,7 +22,10 @@ const ModelInputFieldComponent = (
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { data: mainModels } = useGetMainModelsQuery(); const { data: mainModels } = useGetMainModelsQuery({
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
});
const data = useMemo(() => { const data = useMemo(() => {
if (!mainModels) { if (!mainModels) {

View File

@ -25,7 +25,10 @@ const ModelSelect = () => {
(state: RootState) => state.generation.model (state: RootState) => state.generation.model
); );
const { data: mainModels, isLoading } = useGetMainModelsQuery(); const { data: mainModels, isLoading } = useGetMainModelsQuery({
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
});
const data = useMemo(() => { const data = useMemo(() => {
if (!mainModels) { if (!mainModels) {

View File

@ -16,7 +16,10 @@ export default function MergeModelsPanel() {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data } = useGetMainModelsQuery(); const { data } = useGetMainModelsQuery({
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
});
const diffusersModels = pickBy( const diffusersModels = pickBy(
data?.entities, data?.entities,

View File

@ -8,7 +8,10 @@ import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList'; import ModelList from './ModelManagerPanel/ModelList';
export default function ModelManagerPanel() { export default function ModelManagerPanel() {
const { data: mainModels } = useGetMainModelsQuery(); const { data: mainModels } = useGetMainModelsQuery({
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
});
const openModel = useAppSelector( const openModel = useAppSelector(
(state: RootState) => state.system.openModel (state: RootState) => state.system.openModel

View File

@ -36,7 +36,10 @@ function ModelFilterButton({
} }
const ModelList = () => { const ModelList = () => {
const { data: mainModels } = useGetMainModelsQuery(); const { data: mainModels } = useGetMainModelsQuery({
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
});
const [renderModelList, setRenderModelList] = React.useState<boolean>(false); const [renderModelList, setRenderModelList] = React.useState<boolean>(false);

View File

@ -2,9 +2,11 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
import { cloneDeep } from 'lodash-es'; import { cloneDeep } from 'lodash-es';
import { import {
AnyModelConfig, AnyModelConfig,
BaseModelType,
ControlNetModelConfig, ControlNetModelConfig,
LoRAModelConfig, LoRAModelConfig,
MainModelConfig, MainModelConfig,
ModelType,
TextualInversionModelConfig, TextualInversionModelConfig,
VaeModelConfig, VaeModelConfig,
} from 'services/api/types'; } from 'services/api/types';
@ -68,21 +70,19 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
return entityArray; return entityArray;
}; };
type MainModelQueryArg = {
model_type: ModelType;
base_models: BaseModelType[];
};
export const modelsApi = api.injectEndpoints({ export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({ getMainModels: build.query<
query: () => { EntityState<MainModelConfigEntity>,
const baseModels = { MainModelQueryArg
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], >({
}; query: (arg: MainModelQueryArg) =>
const baseModelsQueryStr = queryString.stringify(baseModels, {}); `models/?${queryString.stringify(arg)}`,
return {
url: `models/?${baseModelsQueryStr}`,
params: {
model_type: 'main',
},
};
},
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ id: 'MainModel', type: LIST_TAG }, { id: 'MainModel', type: LIST_TAG },