mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Parametrize useGetMainModelsQuery
This commit is contained in:
parent
7ce43692c2
commit
4d9a342437
@ -3,13 +3,8 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import {
|
||||
modelsApi,
|
||||
useGetMainModelsQuery,
|
||||
} from '../../services/api/endpoints/models';
|
||||
import { modelsApi } from '../../services/api/endpoints/models';
|
||||
|
||||
const readinessSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
@ -38,7 +33,10 @@ const readinessSelector = createSelector(
|
||||
}
|
||||
|
||||
const { isSuccess: mainModelsSuccessfullyLoaded } =
|
||||
modelsApi.endpoints.getMainModels.select()(state);
|
||||
modelsApi.endpoints.getMainModels.select({
|
||||
model_type: 'main',
|
||||
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
||||
})(state);
|
||||
if (!mainModelsSuccessfullyLoaded) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('Models are not loaded');
|
||||
|
@ -22,7 +22,10 @@ const ModelInputFieldComponent = (
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: mainModels } = useGetMainModelsQuery();
|
||||
const { data: mainModels } = useGetMainModelsQuery({
|
||||
model_type: 'main',
|
||||
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
||||
});
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!mainModels) {
|
||||
|
@ -25,7 +25,10 @@ const ModelSelect = () => {
|
||||
(state: RootState) => state.generation.model
|
||||
);
|
||||
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery({
|
||||
model_type: 'main',
|
||||
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
||||
});
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!mainModels) {
|
||||
|
@ -16,7 +16,10 @@ export default function MergeModelsPanel() {
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { data } = useGetMainModelsQuery();
|
||||
const { data } = useGetMainModelsQuery({
|
||||
model_type: 'main',
|
||||
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
||||
});
|
||||
|
||||
const diffusersModels = pickBy(
|
||||
data?.entities,
|
||||
|
@ -8,7 +8,10 @@ import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
|
||||
export default function ModelManagerPanel() {
|
||||
const { data: mainModels } = useGetMainModelsQuery();
|
||||
const { data: mainModels } = useGetMainModelsQuery({
|
||||
model_type: 'main',
|
||||
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
||||
});
|
||||
|
||||
const openModel = useAppSelector(
|
||||
(state: RootState) => state.system.openModel
|
||||
|
@ -36,7 +36,10 @@ function ModelFilterButton({
|
||||
}
|
||||
|
||||
const ModelList = () => {
|
||||
const { data: mainModels } = useGetMainModelsQuery();
|
||||
const { data: mainModels } = useGetMainModelsQuery({
|
||||
model_type: 'main',
|
||||
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
||||
});
|
||||
|
||||
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
|
||||
|
||||
|
@ -2,9 +2,11 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { cloneDeep } from 'lodash-es';
|
||||
import {
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ControlNetModelConfig,
|
||||
LoRAModelConfig,
|
||||
MainModelConfig,
|
||||
ModelType,
|
||||
TextualInversionModelConfig,
|
||||
VaeModelConfig,
|
||||
} from 'services/api/types';
|
||||
@ -68,21 +70,19 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
|
||||
return entityArray;
|
||||
};
|
||||
|
||||
type MainModelQueryArg = {
|
||||
model_type: ModelType;
|
||||
base_models: BaseModelType[];
|
||||
};
|
||||
|
||||
export const modelsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
|
||||
query: () => {
|
||||
const baseModels = {
|
||||
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
||||
};
|
||||
const baseModelsQueryStr = queryString.stringify(baseModels, {});
|
||||
return {
|
||||
url: `models/?${baseModelsQueryStr}`,
|
||||
params: {
|
||||
model_type: 'main',
|
||||
},
|
||||
};
|
||||
},
|
||||
getMainModels: build.query<
|
||||
EntityState<MainModelConfigEntity>,
|
||||
MainModelQueryArg
|
||||
>({
|
||||
query: (arg: MainModelQueryArg) =>
|
||||
`models/?${queryString.stringify(arg)}`,
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
{ id: 'MainModel', type: LIST_TAG },
|
||||
|
Loading…
Reference in New Issue
Block a user