mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): create rtk-query hooks for individual model types
Eg `useGetMainModelsQuery()`, `useGetLoRAModelsQuery()` instead of `useListModelsQuery({base_type})`. Add specific adapters for each model type. Just more organised and easier to consume models now. Also updated LoRA UI to use the model name.
This commit is contained in:
parent
c21b56ba31
commit
52a09422c7
@ -16,18 +16,18 @@ const ParamLora = (props: Props) => {
|
|||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
dispatch(loraWeightChanged({ name: lora.name, weight: v }));
|
dispatch(loraWeightChanged({ id: lora.id, weight: v }));
|
||||||
},
|
},
|
||||||
[dispatch, lora.name]
|
[dispatch, lora.id]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleReset = useCallback(() => {
|
const handleReset = useCallback(() => {
|
||||||
dispatch(loraWeightChanged({ name: lora.name, weight: 1 }));
|
dispatch(loraWeightChanged({ id: lora.id, weight: 1 }));
|
||||||
}, [dispatch, lora.name]);
|
}, [dispatch, lora.id]);
|
||||||
|
|
||||||
const handleRemoveLora = useCallback(() => {
|
const handleRemoveLora = useCallback(() => {
|
||||||
dispatch(loraRemoved(lora.name));
|
dispatch(loraRemoved(lora.id));
|
||||||
}, [dispatch, lora.name]);
|
}, [dispatch, lora.id]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex sx={{ gap: 2.5, alignItems: 'flex-end' }}>
|
<Flex sx={{ gap: 2.5, alignItems: 'flex-end' }}>
|
||||||
|
@ -6,7 +6,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|||||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { forwardRef, useCallback, useMemo } from 'react';
|
import { forwardRef, useCallback, useMemo } from 'react';
|
||||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { loraAdded } from '../store/loraSlice';
|
import { loraAdded } from '../store/loraSlice';
|
||||||
|
|
||||||
type LoraSelectItem = {
|
type LoraSelectItem = {
|
||||||
@ -26,7 +26,7 @@ const selector = createSelector(
|
|||||||
const ParamLoraSelect = () => {
|
const ParamLoraSelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { loras } = useAppSelector(selector);
|
const { loras } = useAppSelector(selector);
|
||||||
const { data: lorasQueryData } = useListModelsQuery({ model_type: 'lora' });
|
const { data: lorasQueryData } = useGetLoRAModelsQuery();
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!lorasQueryData) {
|
if (!lorasQueryData) {
|
||||||
@ -52,9 +52,13 @@ const ParamLoraSelect = () => {
|
|||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: string[]) => {
|
(v: string[]) => {
|
||||||
v[0] && dispatch(loraAdded(v[0]));
|
const loraEntity = lorasQueryData?.entities[v[0]];
|
||||||
|
if (!loraEntity) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
v[0] && dispatch(loraAdded(loraEntity));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch, lorasQueryData?.entities]
|
||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
export type Lora = {
|
export type Lora = {
|
||||||
|
id: string;
|
||||||
name: string;
|
name: string;
|
||||||
weight: number;
|
weight: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const defaultLoRAConfig: Omit<Lora, 'name'> = {
|
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
|
||||||
weight: 1,
|
weight: 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -21,20 +23,20 @@ export const loraSlice = createSlice({
|
|||||||
name: 'lora',
|
name: 'lora',
|
||||||
initialState: intialLoraState,
|
initialState: intialLoraState,
|
||||||
reducers: {
|
reducers: {
|
||||||
loraAdded: (state, action: PayloadAction<string>) => {
|
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
|
||||||
const name = action.payload;
|
const { name, id } = action.payload;
|
||||||
state.loras[name] = { name, ...defaultLoRAConfig };
|
state.loras[id] = { id, name, ...defaultLoRAConfig };
|
||||||
},
|
},
|
||||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||||
const name = action.payload;
|
const id = action.payload;
|
||||||
delete state.loras[name];
|
delete state.loras[id];
|
||||||
},
|
},
|
||||||
loraWeightChanged: (
|
loraWeightChanged: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{ name: string; weight: number }>
|
action: PayloadAction<{ id: string; weight: number }>
|
||||||
) => {
|
) => {
|
||||||
const { name, weight } = action.payload;
|
const { id, weight } = action.payload;
|
||||||
state.loras[name].weight = weight;
|
state.loras[id].weight = weight;
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -10,7 +10,7 @@ import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/component
|
|||||||
import { forEach, isString } from 'lodash-es';
|
import { forEach, isString } from 'lodash-es';
|
||||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { FieldComponentProps } from './types';
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
const LoRAModelInputFieldComponent = (
|
const LoRAModelInputFieldComponent = (
|
||||||
@ -24,9 +24,7 @@ const LoRAModelInputFieldComponent = (
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { data: loraModels } = useListModelsQuery({
|
const { data: loraModels } = useGetLoRAModelsQuery();
|
||||||
model_type: 'lora',
|
|
||||||
});
|
|
||||||
|
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() => loraModels?.entities[field.value ?? loraModels.ids[0]],
|
() => loraModels?.entities[field.value ?? loraModels.ids[0]],
|
||||||
|
@ -11,7 +11,7 @@ import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/component
|
|||||||
import { forEach, isString } from 'lodash-es';
|
import { forEach, isString } from 'lodash-es';
|
||||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { FieldComponentProps } from './types';
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
const ModelInputFieldComponent = (
|
const ModelInputFieldComponent = (
|
||||||
@ -22,9 +22,7 @@ const ModelInputFieldComponent = (
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { data: mainModels } = useListModelsQuery({
|
const { data: mainModels } = useGetMainModelsQuery();
|
||||||
model_type: 'main',
|
|
||||||
});
|
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!mainModels) {
|
if (!mainModels) {
|
||||||
|
@ -10,7 +10,7 @@ import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/component
|
|||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { FieldComponentProps } from './types';
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
const VaeModelInputFieldComponent = (
|
const VaeModelInputFieldComponent = (
|
||||||
@ -24,9 +24,7 @@ const VaeModelInputFieldComponent = (
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { data: vaeModels } = useListModelsQuery({
|
const { data: vaeModels } = useGetVaeModelsQuery();
|
||||||
model_type: 'vae',
|
|
||||||
});
|
|
||||||
|
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() => vaeModels?.entities[field.value ?? vaeModels.ids[0]],
|
() => vaeModels?.entities[field.value ?? vaeModels.ids[0]],
|
||||||
|
@ -8,7 +8,7 @@ import { modelSelected } from 'features/parameters/store/generationSlice';
|
|||||||
import { SelectItem } from '@mantine/core';
|
import { SelectItem } from '@mantine/core';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { forEach, isString } from 'lodash-es';
|
import { forEach, isString } from 'lodash-es';
|
||||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
export const MODEL_TYPE_MAP = {
|
export const MODEL_TYPE_MAP = {
|
||||||
'sd-1': 'Stable Diffusion 1.x',
|
'sd-1': 'Stable Diffusion 1.x',
|
||||||
@ -23,9 +23,7 @@ const ModelSelect = () => {
|
|||||||
(state: RootState) => state.generation.model
|
(state: RootState) => state.generation.model
|
||||||
);
|
);
|
||||||
|
|
||||||
const { data: mainModels, isLoading } = useListModelsQuery({
|
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||||
model_type: 'main',
|
|
||||||
});
|
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!mainModels) {
|
if (!mainModels) {
|
||||||
|
@ -6,7 +6,7 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
|||||||
|
|
||||||
import { SelectItem } from '@mantine/core';
|
import { SelectItem } from '@mantine/core';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { vaeSelected } from 'features/parameters/store/generationSlice';
|
import { vaeSelected } from 'features/parameters/store/generationSlice';
|
||||||
@ -16,9 +16,7 @@ const VAESelect = () => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { data: vaeModels } = useListModelsQuery({
|
const { data: vaeModels } = useGetVaeModelsQuery();
|
||||||
model_type: 'vae',
|
|
||||||
});
|
|
||||||
|
|
||||||
const selectedModelId = useAppSelector(
|
const selectedModelId = useAppSelector(
|
||||||
(state: RootState) => state.generation.vae
|
(state: RootState) => state.generation.vae
|
||||||
|
@ -9,16 +9,14 @@ import IAISlider from 'common/components/IAISlider';
|
|||||||
import { pickBy } from 'lodash-es';
|
import { pickBy } from 'lodash-es';
|
||||||
import { useState } from 'react';
|
import { useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
export default function MergeModelsPanel() {
|
export default function MergeModelsPanel() {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { data } = useListModelsQuery({
|
const { data } = useGetMainModelsQuery();
|
||||||
model_type: 'main',
|
|
||||||
});
|
|
||||||
|
|
||||||
const diffusersModels = pickBy(
|
const diffusersModels = pickBy(
|
||||||
data?.entities,
|
data?.entities,
|
||||||
|
@ -2,15 +2,13 @@ import { Flex } from '@chakra-ui/react';
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
|
||||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||||
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||||
import ModelList from './ModelManagerPanel/ModelList';
|
import ModelList from './ModelManagerPanel/ModelList';
|
||||||
|
|
||||||
export default function ModelManagerPanel() {
|
export default function ModelManagerPanel() {
|
||||||
const { data: mainModels } = useListModelsQuery({
|
const { data: mainModels } = useGetMainModelsQuery();
|
||||||
model_type: 'main',
|
|
||||||
});
|
|
||||||
|
|
||||||
const openModel = useAppSelector(
|
const openModel = useAppSelector(
|
||||||
(state: RootState) => state.system.openModel
|
(state: RootState) => state.system.openModel
|
||||||
|
@ -8,7 +8,7 @@ import { useTranslation } from 'react-i18next';
|
|||||||
|
|
||||||
import type { ChangeEvent, ReactNode } from 'react';
|
import type { ChangeEvent, ReactNode } from 'react';
|
||||||
import React, { useMemo, useState, useTransition } from 'react';
|
import React, { useMemo, useState, useTransition } from 'react';
|
||||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
function ModelFilterButton({
|
function ModelFilterButton({
|
||||||
label,
|
label,
|
||||||
@ -36,9 +36,7 @@ function ModelFilterButton({
|
|||||||
}
|
}
|
||||||
|
|
||||||
const ModelList = () => {
|
const ModelList = () => {
|
||||||
const { data: mainModels } = useListModelsQuery({
|
const { data: mainModels } = useGetMainModelsQuery();
|
||||||
model_type: 'main',
|
|
||||||
});
|
|
||||||
|
|
||||||
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
|
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
|
||||||
|
|
||||||
|
@ -1,35 +1,85 @@
|
|||||||
import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
|
import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
|
||||||
import { keyBy } from 'lodash-es';
|
import { cloneDeep } from 'lodash-es';
|
||||||
import { ModelsList } from 'services/api/types';
|
import {
|
||||||
|
AnyModelConfig,
|
||||||
|
ControlNetModelConfig,
|
||||||
|
LoRAModelConfig,
|
||||||
|
MainModelConfig,
|
||||||
|
TextualInversionModelConfig,
|
||||||
|
VaeModelConfig,
|
||||||
|
} from 'services/api/types';
|
||||||
|
|
||||||
import { ApiFullTagDescription, LIST_TAG, api } from '..';
|
import { ApiFullTagDescription, LIST_TAG, api } from '..';
|
||||||
import { paths } from '../schema';
|
|
||||||
|
|
||||||
type ModelConfig = ModelsList['models'][number];
|
export type MainModelConfigEntity = MainModelConfig & { id: string };
|
||||||
|
|
||||||
type ListModelsArg = NonNullable<
|
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
|
||||||
paths['/api/v1/models/']['get']['parameters']['query']
|
|
||||||
>;
|
|
||||||
|
|
||||||
const modelsAdapter = createEntityAdapter<ModelConfig>({
|
export type ControlNetModelConfigEntity = ControlNetModelConfig & {
|
||||||
selectId: (model) => getModelId(model),
|
id: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
|
||||||
|
id: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type VaeModelConfigEntity = VaeModelConfig & { id: string };
|
||||||
|
|
||||||
|
type AnyModelConfigEntity =
|
||||||
|
| MainModelConfigEntity
|
||||||
|
| LoRAModelConfigEntity
|
||||||
|
| ControlNetModelConfigEntity
|
||||||
|
| TextualInversionModelConfigEntity
|
||||||
|
| VaeModelConfigEntity;
|
||||||
|
|
||||||
|
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||||
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
|
});
|
||||||
|
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||||
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
|
});
|
||||||
|
const controlNetModelsAdapter =
|
||||||
|
createEntityAdapter<ControlNetModelConfigEntity>({
|
||||||
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
|
});
|
||||||
|
const textualInversionModelsAdapter =
|
||||||
|
createEntityAdapter<TextualInversionModelConfigEntity>({
|
||||||
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
|
});
|
||||||
|
const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
});
|
});
|
||||||
|
|
||||||
const getModelId = ({ base_model, type, name }: ModelConfig) =>
|
export const getModelId = ({ base_model, type, name }: AnyModelConfig) =>
|
||||||
`${base_model}/${type}/${name}`;
|
`${base_model}/${type}/${name}`;
|
||||||
|
|
||||||
|
const createModelEntities = <T extends AnyModelConfigEntity>(
|
||||||
|
models: AnyModelConfig[]
|
||||||
|
): T[] => {
|
||||||
|
const entityArray: T[] = [];
|
||||||
|
models.forEach((model) => {
|
||||||
|
const entity = {
|
||||||
|
...cloneDeep(model),
|
||||||
|
id: getModelId(model),
|
||||||
|
} as T;
|
||||||
|
entityArray.push(entity);
|
||||||
|
});
|
||||||
|
return entityArray;
|
||||||
|
};
|
||||||
|
|
||||||
export const modelsApi = api.injectEndpoints({
|
export const modelsApi = api.injectEndpoints({
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
listModels: build.query<EntityState<ModelConfig>, ListModelsArg>({
|
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
|
||||||
query: (arg) => ({ url: 'models/', params: arg }),
|
query: () => ({ url: 'models/', params: { model_type: 'main' } }),
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result, error, arg) => {
|
||||||
const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST_TAG }];
|
const tags: ApiFullTagDescription[] = [
|
||||||
|
{ id: 'MainModel', type: LIST_TAG },
|
||||||
|
];
|
||||||
|
|
||||||
if (result) {
|
if (result) {
|
||||||
tags.push(
|
tags.push(
|
||||||
...result.ids.map((id) => ({
|
...result.ids.map((id) => ({
|
||||||
type: 'Model' as const,
|
type: 'MainModel' as const,
|
||||||
id,
|
id,
|
||||||
}))
|
}))
|
||||||
);
|
);
|
||||||
@ -37,14 +87,161 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
|
|
||||||
return tags;
|
return tags;
|
||||||
},
|
},
|
||||||
transformResponse: (response: ModelsList, meta, arg) => {
|
transformResponse: (
|
||||||
return modelsAdapter.setAll(
|
response: { models: MainModelConfig[] },
|
||||||
modelsAdapter.getInitialState(),
|
meta,
|
||||||
keyBy(response.models, getModelId)
|
arg
|
||||||
|
) => {
|
||||||
|
const entities = createModelEntities<MainModelConfigEntity>(
|
||||||
|
response.models
|
||||||
|
);
|
||||||
|
return mainModelsAdapter.setAll(
|
||||||
|
mainModelsAdapter.getInitialState(),
|
||||||
|
entities
|
||||||
|
);
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||||
|
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
||||||
|
providesTags: (result, error, arg) => {
|
||||||
|
const tags: ApiFullTagDescription[] = [
|
||||||
|
{ id: 'LoRAModel', type: LIST_TAG },
|
||||||
|
];
|
||||||
|
|
||||||
|
if (result) {
|
||||||
|
tags.push(
|
||||||
|
...result.ids.map((id) => ({
|
||||||
|
type: 'LoRAModel' as const,
|
||||||
|
id,
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return tags;
|
||||||
|
},
|
||||||
|
transformResponse: (
|
||||||
|
response: { models: LoRAModelConfig[] },
|
||||||
|
meta,
|
||||||
|
arg
|
||||||
|
) => {
|
||||||
|
const entities = createModelEntities<LoRAModelConfigEntity>(
|
||||||
|
response.models
|
||||||
|
);
|
||||||
|
return loraModelsAdapter.setAll(
|
||||||
|
loraModelsAdapter.getInitialState(),
|
||||||
|
entities
|
||||||
|
);
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
getControlNetModels: build.query<
|
||||||
|
EntityState<ControlNetModelConfigEntity>,
|
||||||
|
void
|
||||||
|
>({
|
||||||
|
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
|
||||||
|
providesTags: (result, error, arg) => {
|
||||||
|
const tags: ApiFullTagDescription[] = [
|
||||||
|
{ id: 'ControlNetModel', type: LIST_TAG },
|
||||||
|
];
|
||||||
|
|
||||||
|
if (result) {
|
||||||
|
tags.push(
|
||||||
|
...result.ids.map((id) => ({
|
||||||
|
type: 'ControlNetModel' as const,
|
||||||
|
id,
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return tags;
|
||||||
|
},
|
||||||
|
transformResponse: (
|
||||||
|
response: { models: ControlNetModelConfig[] },
|
||||||
|
meta,
|
||||||
|
arg
|
||||||
|
) => {
|
||||||
|
const entities = createModelEntities<ControlNetModelConfigEntity>(
|
||||||
|
response.models
|
||||||
|
);
|
||||||
|
return controlNetModelsAdapter.setAll(
|
||||||
|
controlNetModelsAdapter.getInitialState(),
|
||||||
|
entities
|
||||||
|
);
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
|
||||||
|
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
|
||||||
|
providesTags: (result, error, arg) => {
|
||||||
|
const tags: ApiFullTagDescription[] = [
|
||||||
|
{ id: 'VaeModel', type: LIST_TAG },
|
||||||
|
];
|
||||||
|
|
||||||
|
if (result) {
|
||||||
|
tags.push(
|
||||||
|
...result.ids.map((id) => ({
|
||||||
|
type: 'VaeModel' as const,
|
||||||
|
id,
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return tags;
|
||||||
|
},
|
||||||
|
transformResponse: (
|
||||||
|
response: { models: VaeModelConfig[] },
|
||||||
|
meta,
|
||||||
|
arg
|
||||||
|
) => {
|
||||||
|
const entities = createModelEntities<VaeModelConfigEntity>(
|
||||||
|
response.models
|
||||||
|
);
|
||||||
|
return vaeModelsAdapter.setAll(
|
||||||
|
vaeModelsAdapter.getInitialState(),
|
||||||
|
entities
|
||||||
|
);
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
getTextualInversionModels: build.query<
|
||||||
|
EntityState<TextualInversionModelConfigEntity>,
|
||||||
|
void
|
||||||
|
>({
|
||||||
|
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
|
||||||
|
providesTags: (result, error, arg) => {
|
||||||
|
const tags: ApiFullTagDescription[] = [
|
||||||
|
{ id: 'TextualInversionModel', type: LIST_TAG },
|
||||||
|
];
|
||||||
|
|
||||||
|
if (result) {
|
||||||
|
tags.push(
|
||||||
|
...result.ids.map((id) => ({
|
||||||
|
type: 'TextualInversionModel' as const,
|
||||||
|
id,
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return tags;
|
||||||
|
},
|
||||||
|
transformResponse: (
|
||||||
|
response: { models: TextualInversionModelConfig[] },
|
||||||
|
meta,
|
||||||
|
arg
|
||||||
|
) => {
|
||||||
|
const entities = createModelEntities<TextualInversionModelConfigEntity>(
|
||||||
|
response.models
|
||||||
|
);
|
||||||
|
return textualInversionModelsAdapter.setAll(
|
||||||
|
textualInversionModelsAdapter.getInitialState(),
|
||||||
|
entities
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
export const { useListModelsQuery } = modelsApi;
|
export const {
|
||||||
|
useGetMainModelsQuery,
|
||||||
|
useGetControlNetModelsQuery,
|
||||||
|
useGetLoRAModelsQuery,
|
||||||
|
useGetTextualInversionModelsQuery,
|
||||||
|
useGetVaeModelsQuery,
|
||||||
|
} = modelsApi;
|
||||||
|
214
invokeai/frontend/web/src/services/api/types.d.ts
vendored
214
invokeai/frontend/web/src/services/api/types.d.ts
vendored
@ -4,94 +4,156 @@ import { components } from './schema';
|
|||||||
type schemas = components['schemas'];
|
type schemas = components['schemas'];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Extracts the schema type from the schema.
|
* Marks the `type` property as required. Use for nodes.
|
||||||
*/
|
*/
|
||||||
type S<T extends keyof components['schemas']> = components['schemas'][T];
|
type TypeReq<T> = O.Required<T, 'type'>;
|
||||||
|
|
||||||
/**
|
|
||||||
* Extracts the node type from the schema.
|
|
||||||
* Also flags the `type` property as required.
|
|
||||||
*/
|
|
||||||
type N<T extends keyof components['schemas']> = O.Required<
|
|
||||||
components['schemas'][T],
|
|
||||||
'type'
|
|
||||||
>;
|
|
||||||
|
|
||||||
// Images
|
// Images
|
||||||
export type ImageDTO = S<'ImageDTO'>;
|
export type ImageDTO = components['schemas']['ImageDTO'];
|
||||||
export type BoardDTO = S<'BoardDTO'>;
|
export type BoardDTO = components['schemas']['BoardDTO'];
|
||||||
export type BoardChanges = S<'BoardChanges'>;
|
export type BoardChanges = components['schemas']['BoardChanges'];
|
||||||
export type ImageChanges = S<'ImageRecordChanges'>;
|
export type ImageChanges = components['schemas']['ImageRecordChanges'];
|
||||||
export type ImageCategory = S<'ImageCategory'>;
|
export type ImageCategory = components['schemas']['ImageCategory'];
|
||||||
export type ResourceOrigin = S<'ResourceOrigin'>;
|
export type ResourceOrigin = components['schemas']['ResourceOrigin'];
|
||||||
export type ImageField = S<'ImageField'>;
|
export type ImageField = components['schemas']['ImageField'];
|
||||||
export type OffsetPaginatedResults_BoardDTO_ =
|
export type OffsetPaginatedResults_BoardDTO_ =
|
||||||
S<'OffsetPaginatedResults_BoardDTO_'>;
|
components['schemas']['OffsetPaginatedResults_BoardDTO_'];
|
||||||
export type OffsetPaginatedResults_ImageDTO_ =
|
export type OffsetPaginatedResults_ImageDTO_ =
|
||||||
S<'OffsetPaginatedResults_ImageDTO_'>;
|
components['schemas']['OffsetPaginatedResults_ImageDTO_'];
|
||||||
|
|
||||||
// Models
|
// Models
|
||||||
export type ModelType = S<'ModelType'>;
|
export type ModelType = components['schemas']['ModelType'];
|
||||||
export type BaseModelType = S<'BaseModelType'>;
|
export type BaseModelType = components['schemas']['BaseModelType'];
|
||||||
export type MainModelField = S<'MainModelField'>;
|
export type MainModelField = components['schemas']['MainModelField'];
|
||||||
export type VAEModelField = S<'VAEModelField'>;
|
export type VAEModelField = components['schemas']['VAEModelField'];
|
||||||
export type LoRAModelField = S<'LoRAModelField'>;
|
export type LoRAModelField = components['schemas']['LoRAModelField'];
|
||||||
export type ModelsList = S<'ModelsList'>;
|
export type ModelsList = components['schemas']['ModelsList'];
|
||||||
export type LoRAModelConfig = S<'LoRAModelConfig'>;
|
|
||||||
|
// Model Configs
|
||||||
|
export type LoRAModelConfig = components['schemas']['LoRAModelConfig'];
|
||||||
|
export type VaeModelConfig = components['schemas']['VaeModelConfig'];
|
||||||
|
export type ControlNetModelConfig =
|
||||||
|
components['schemas']['ControlNetModelConfig'];
|
||||||
|
export type TextualInversionModelConfig =
|
||||||
|
components['schemas']['TextualInversionModelConfig'];
|
||||||
|
export type MainModelConfig =
|
||||||
|
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
|
||||||
|
| components['schemas']['StableDiffusion1ModelDiffusersConfig']
|
||||||
|
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
|
||||||
|
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
|
||||||
|
export type AnyModelConfig =
|
||||||
|
| LoRAModelConfig
|
||||||
|
| VaeModelConfig
|
||||||
|
| ControlNetModelConfig
|
||||||
|
| TextualInversionModelConfig
|
||||||
|
| MainModelConfig;
|
||||||
|
|
||||||
// Graphs
|
// Graphs
|
||||||
export type Graph = S<'Graph'>;
|
export type Graph = components['schemas']['Graph'];
|
||||||
export type Edge = S<'Edge'>;
|
export type Edge = components['schemas']['Edge'];
|
||||||
export type GraphExecutionState = S<'GraphExecutionState'>;
|
export type GraphExecutionState = components['schemas']['GraphExecutionState'];
|
||||||
|
|
||||||
// General nodes
|
// General nodes
|
||||||
export type CollectInvocation = N<'CollectInvocation'>;
|
export type CollectInvocation = TypeReq<
|
||||||
export type IterateInvocation = N<'IterateInvocation'>;
|
components['schemas']['CollectInvocation']
|
||||||
export type RangeInvocation = N<'RangeInvocation'>;
|
>;
|
||||||
export type RandomRangeInvocation = N<'RandomRangeInvocation'>;
|
export type IterateInvocation = TypeReq<
|
||||||
export type RangeOfSizeInvocation = N<'RangeOfSizeInvocation'>;
|
components['schemas']['IterateInvocation']
|
||||||
export type InpaintInvocation = N<'InpaintInvocation'>;
|
>;
|
||||||
export type ImageResizeInvocation = N<'ImageResizeInvocation'>;
|
export type RangeInvocation = TypeReq<components['schemas']['RangeInvocation']>;
|
||||||
export type RandomIntInvocation = N<'RandomIntInvocation'>;
|
export type RandomRangeInvocation = TypeReq<
|
||||||
export type CompelInvocation = N<'CompelInvocation'>;
|
components['schemas']['RandomRangeInvocation']
|
||||||
export type DynamicPromptInvocation = N<'DynamicPromptInvocation'>;
|
>;
|
||||||
export type NoiseInvocation = N<'NoiseInvocation'>;
|
export type RangeOfSizeInvocation = TypeReq<
|
||||||
export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>;
|
components['schemas']['RangeOfSizeInvocation']
|
||||||
export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>;
|
>;
|
||||||
export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
|
export type InpaintInvocation = TypeReq<
|
||||||
export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
|
components['schemas']['InpaintInvocation']
|
||||||
export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>;
|
>;
|
||||||
export type MainModelLoaderInvocation = N<'MainModelLoaderInvocation'>;
|
export type ImageResizeInvocation = TypeReq<
|
||||||
export type LoraLoaderInvocation = N<'LoraLoaderInvocation'>;
|
components['schemas']['ImageResizeInvocation']
|
||||||
|
>;
|
||||||
|
export type RandomIntInvocation = TypeReq<
|
||||||
|
components['schemas']['RandomIntInvocation']
|
||||||
|
>;
|
||||||
|
export type CompelInvocation = TypeReq<
|
||||||
|
components['schemas']['CompelInvocation']
|
||||||
|
>;
|
||||||
|
export type DynamicPromptInvocation = TypeReq<
|
||||||
|
components['schemas']['DynamicPromptInvocation']
|
||||||
|
>;
|
||||||
|
export type NoiseInvocation = TypeReq<components['schemas']['NoiseInvocation']>;
|
||||||
|
export type TextToLatentsInvocation = TypeReq<
|
||||||
|
components['schemas']['TextToLatentsInvocation']
|
||||||
|
>;
|
||||||
|
export type LatentsToLatentsInvocation = TypeReq<
|
||||||
|
components['schemas']['LatentsToLatentsInvocation']
|
||||||
|
>;
|
||||||
|
export type ImageToLatentsInvocation = TypeReq<
|
||||||
|
components['schemas']['ImageToLatentsInvocation']
|
||||||
|
>;
|
||||||
|
export type LatentsToImageInvocation = TypeReq<
|
||||||
|
components['schemas']['LatentsToImageInvocation']
|
||||||
|
>;
|
||||||
|
export type ImageCollectionInvocation = TypeReq<
|
||||||
|
components['schemas']['ImageCollectionInvocation']
|
||||||
|
>;
|
||||||
|
export type MainModelLoaderInvocation = TypeReq<
|
||||||
|
components['schemas']['MainModelLoaderInvocation']
|
||||||
|
>;
|
||||||
|
export type LoraLoaderInvocation = TypeReq<
|
||||||
|
components['schemas']['LoraLoaderInvocation']
|
||||||
|
>;
|
||||||
|
|
||||||
// ControlNet Nodes
|
// ControlNet Nodes
|
||||||
export type ControlNetInvocation = N<'ControlNetInvocation'>;
|
export type ControlNetInvocation = TypeReq<
|
||||||
export type CannyImageProcessorInvocation = N<'CannyImageProcessorInvocation'>;
|
components['schemas']['ControlNetInvocation']
|
||||||
export type ContentShuffleImageProcessorInvocation =
|
>;
|
||||||
N<'ContentShuffleImageProcessorInvocation'>;
|
export type CannyImageProcessorInvocation = TypeReq<
|
||||||
export type HedImageProcessorInvocation = N<'HedImageProcessorInvocation'>;
|
components['schemas']['CannyImageProcessorInvocation']
|
||||||
export type LineartAnimeImageProcessorInvocation =
|
>;
|
||||||
N<'LineartAnimeImageProcessorInvocation'>;
|
export type ContentShuffleImageProcessorInvocation = TypeReq<
|
||||||
export type LineartImageProcessorInvocation =
|
components['schemas']['ContentShuffleImageProcessorInvocation']
|
||||||
N<'LineartImageProcessorInvocation'>;
|
>;
|
||||||
export type MediapipeFaceProcessorInvocation =
|
export type HedImageProcessorInvocation = TypeReq<
|
||||||
N<'MediapipeFaceProcessorInvocation'>;
|
components['schemas']['HedImageProcessorInvocation']
|
||||||
export type MidasDepthImageProcessorInvocation =
|
>;
|
||||||
N<'MidasDepthImageProcessorInvocation'>;
|
export type LineartAnimeImageProcessorInvocation = TypeReq<
|
||||||
export type MlsdImageProcessorInvocation = N<'MlsdImageProcessorInvocation'>;
|
components['schemas']['LineartAnimeImageProcessorInvocation']
|
||||||
export type NormalbaeImageProcessorInvocation =
|
>;
|
||||||
N<'NormalbaeImageProcessorInvocation'>;
|
export type LineartImageProcessorInvocation = TypeReq<
|
||||||
export type OpenposeImageProcessorInvocation =
|
components['schemas']['LineartImageProcessorInvocation']
|
||||||
N<'OpenposeImageProcessorInvocation'>;
|
>;
|
||||||
export type PidiImageProcessorInvocation = N<'PidiImageProcessorInvocation'>;
|
export type MediapipeFaceProcessorInvocation = TypeReq<
|
||||||
export type ZoeDepthImageProcessorInvocation =
|
components['schemas']['MediapipeFaceProcessorInvocation']
|
||||||
N<'ZoeDepthImageProcessorInvocation'>;
|
>;
|
||||||
|
export type MidasDepthImageProcessorInvocation = TypeReq<
|
||||||
|
components['schemas']['MidasDepthImageProcessorInvocation']
|
||||||
|
>;
|
||||||
|
export type MlsdImageProcessorInvocation = TypeReq<
|
||||||
|
components['schemas']['MlsdImageProcessorInvocation']
|
||||||
|
>;
|
||||||
|
export type NormalbaeImageProcessorInvocation = TypeReq<
|
||||||
|
components['schemas']['NormalbaeImageProcessorInvocation']
|
||||||
|
>;
|
||||||
|
export type OpenposeImageProcessorInvocation = TypeReq<
|
||||||
|
components['schemas']['OpenposeImageProcessorInvocation']
|
||||||
|
>;
|
||||||
|
export type PidiImageProcessorInvocation = TypeReq<
|
||||||
|
components['schemas']['PidiImageProcessorInvocation']
|
||||||
|
>;
|
||||||
|
export type ZoeDepthImageProcessorInvocation = TypeReq<
|
||||||
|
components['schemas']['ZoeDepthImageProcessorInvocation']
|
||||||
|
>;
|
||||||
|
|
||||||
// Node Outputs
|
// Node Outputs
|
||||||
export type ImageOutput = S<'ImageOutput'>;
|
export type ImageOutput = components['schemas']['ImageOutput'];
|
||||||
export type MaskOutput = S<'MaskOutput'>;
|
export type MaskOutput = components['schemas']['MaskOutput'];
|
||||||
export type PromptOutput = S<'PromptOutput'>;
|
export type PromptOutput = components['schemas']['PromptOutput'];
|
||||||
export type IterateInvocationOutput = S<'IterateInvocationOutput'>;
|
export type IterateInvocationOutput =
|
||||||
export type CollectInvocationOutput = S<'CollectInvocationOutput'>;
|
components['schemas']['IterateInvocationOutput'];
|
||||||
export type LatentsOutput = S<'LatentsOutput'>;
|
export type CollectInvocationOutput =
|
||||||
export type GraphInvocationOutput = S<'GraphInvocationOutput'>;
|
components['schemas']['CollectInvocationOutput'];
|
||||||
|
export type LatentsOutput = components['schemas']['LatentsOutput'];
|
||||||
|
export type GraphInvocationOutput =
|
||||||
|
components['schemas']['GraphInvocationOutput'];
|
||||||
|
Loading…
Reference in New Issue
Block a user