mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Parametrize useGetMainModelsQuery
This commit is contained in:
parent
7ce43692c2
commit
4d9a342437
@ -3,13 +3,8 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import {
|
import { modelsApi } from '../../services/api/endpoints/models';
|
||||||
modelsApi,
|
|
||||||
useGetMainModelsQuery,
|
|
||||||
} from '../../services/api/endpoints/models';
|
|
||||||
|
|
||||||
const readinessSelector = createSelector(
|
const readinessSelector = createSelector(
|
||||||
[stateSelector, activeTabNameSelector],
|
[stateSelector, activeTabNameSelector],
|
||||||
@ -38,7 +33,10 @@ const readinessSelector = createSelector(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const { isSuccess: mainModelsSuccessfullyLoaded } =
|
const { isSuccess: mainModelsSuccessfullyLoaded } =
|
||||||
modelsApi.endpoints.getMainModels.select()(state);
|
modelsApi.endpoints.getMainModels.select({
|
||||||
|
model_type: 'main',
|
||||||
|
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
||||||
|
})(state);
|
||||||
if (!mainModelsSuccessfullyLoaded) {
|
if (!mainModelsSuccessfullyLoaded) {
|
||||||
isReady = false;
|
isReady = false;
|
||||||
reasonsWhyNotReady.push('Models are not loaded');
|
reasonsWhyNotReady.push('Models are not loaded');
|
||||||
|
@ -22,7 +22,10 @@ const ModelInputFieldComponent = (
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { data: mainModels } = useGetMainModelsQuery();
|
const { data: mainModels } = useGetMainModelsQuery({
|
||||||
|
model_type: 'main',
|
||||||
|
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
||||||
|
});
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!mainModels) {
|
if (!mainModels) {
|
||||||
|
@ -25,7 +25,10 @@ const ModelSelect = () => {
|
|||||||
(state: RootState) => state.generation.model
|
(state: RootState) => state.generation.model
|
||||||
);
|
);
|
||||||
|
|
||||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
const { data: mainModels, isLoading } = useGetMainModelsQuery({
|
||||||
|
model_type: 'main',
|
||||||
|
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
||||||
|
});
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!mainModels) {
|
if (!mainModels) {
|
||||||
|
@ -16,7 +16,10 @@ export default function MergeModelsPanel() {
|
|||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { data } = useGetMainModelsQuery();
|
const { data } = useGetMainModelsQuery({
|
||||||
|
model_type: 'main',
|
||||||
|
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
||||||
|
});
|
||||||
|
|
||||||
const diffusersModels = pickBy(
|
const diffusersModels = pickBy(
|
||||||
data?.entities,
|
data?.entities,
|
||||||
|
@ -8,7 +8,10 @@ import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
|||||||
import ModelList from './ModelManagerPanel/ModelList';
|
import ModelList from './ModelManagerPanel/ModelList';
|
||||||
|
|
||||||
export default function ModelManagerPanel() {
|
export default function ModelManagerPanel() {
|
||||||
const { data: mainModels } = useGetMainModelsQuery();
|
const { data: mainModels } = useGetMainModelsQuery({
|
||||||
|
model_type: 'main',
|
||||||
|
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
||||||
|
});
|
||||||
|
|
||||||
const openModel = useAppSelector(
|
const openModel = useAppSelector(
|
||||||
(state: RootState) => state.system.openModel
|
(state: RootState) => state.system.openModel
|
||||||
|
@ -36,7 +36,10 @@ function ModelFilterButton({
|
|||||||
}
|
}
|
||||||
|
|
||||||
const ModelList = () => {
|
const ModelList = () => {
|
||||||
const { data: mainModels } = useGetMainModelsQuery();
|
const { data: mainModels } = useGetMainModelsQuery({
|
||||||
|
model_type: 'main',
|
||||||
|
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
||||||
|
});
|
||||||
|
|
||||||
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
|
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
|
||||||
|
|
||||||
|
@ -2,9 +2,11 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
|
|||||||
import { cloneDeep } from 'lodash-es';
|
import { cloneDeep } from 'lodash-es';
|
||||||
import {
|
import {
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
|
BaseModelType,
|
||||||
ControlNetModelConfig,
|
ControlNetModelConfig,
|
||||||
LoRAModelConfig,
|
LoRAModelConfig,
|
||||||
MainModelConfig,
|
MainModelConfig,
|
||||||
|
ModelType,
|
||||||
TextualInversionModelConfig,
|
TextualInversionModelConfig,
|
||||||
VaeModelConfig,
|
VaeModelConfig,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
@ -68,21 +70,19 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
|
|||||||
return entityArray;
|
return entityArray;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
type MainModelQueryArg = {
|
||||||
|
model_type: ModelType;
|
||||||
|
base_models: BaseModelType[];
|
||||||
|
};
|
||||||
|
|
||||||
export const modelsApi = api.injectEndpoints({
|
export const modelsApi = api.injectEndpoints({
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
|
getMainModels: build.query<
|
||||||
query: () => {
|
EntityState<MainModelConfigEntity>,
|
||||||
const baseModels = {
|
MainModelQueryArg
|
||||||
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
>({
|
||||||
};
|
query: (arg: MainModelQueryArg) =>
|
||||||
const baseModelsQueryStr = queryString.stringify(baseModels, {});
|
`models/?${queryString.stringify(arg)}`,
|
||||||
return {
|
|
||||||
url: `models/?${baseModelsQueryStr}`,
|
|
||||||
params: {
|
|
||||||
model_type: 'main',
|
|
||||||
},
|
|
||||||
};
|
|
||||||
},
|
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result, error, arg) => {
|
||||||
const tags: ApiFullTagDescription[] = [
|
const tags: ApiFullTagDescription[] = [
|
||||||
{ id: 'MainModel', type: LIST_TAG },
|
{ id: 'MainModel', type: LIST_TAG },
|
||||||
|
Loading…
Reference in New Issue
Block a user