mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): fix workflow editor model selector, excise ONNX
Ensure workflow editor model selector component gets a value This introduced some funky type issues related to ONNX models. ONNX doesn't work anyways (unmaintained). Instead of fixing the types to work with a non-working feature, ONNX is now removed entirely from the UI. - Remove all refs to ONNX (and Olives) - Fix some type issues - Add ONNX nodes to the nodes denylist (so they are not visible in UI) - Update VAE graph helper, which still had some ONNX logic. It's a very simple change and doesn't change any logic. Just removes some conditions that were for ONNX. I tested it and nothing broke. - Regenerate types - Fix prettier and eslint ignores for generated types - Lint
This commit is contained in:
parent
ebe717099e
commit
b1b5c0d3b2
@ -7,4 +7,4 @@ stats.html
|
|||||||
index.html
|
index.html
|
||||||
.yarn/
|
.yarn/
|
||||||
*.scss
|
*.scss
|
||||||
src/services/api/schema.d.ts
|
src/services/api/schema.ts
|
||||||
|
@ -9,7 +9,7 @@ index.html
|
|||||||
.yarn/
|
.yarn/
|
||||||
.yalc/
|
.yalc/
|
||||||
*.scss
|
*.scss
|
||||||
src/services/api/schema.d.ts
|
src/services/api/schema.ts
|
||||||
static/
|
static/
|
||||||
src/theme/css/overlayscrollbars.css
|
src/theme/css/overlayscrollbars.css
|
||||||
src/theme_/css/overlayscrollbars.css
|
src/theme_/css/overlayscrollbars.css
|
||||||
|
@ -70,7 +70,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
|||||||
step={1}
|
step={1}
|
||||||
w={numberInputWidth}
|
w={numberInputWidth}
|
||||||
defaultValue={90}
|
defaultValue={90}
|
||||||
/>
|
/>
|
||||||
</InvControl>
|
</InvControl>
|
||||||
<InvControl label="Green">
|
<InvControl label="Green">
|
||||||
<InvNumberInput
|
<InvNumberInput
|
||||||
@ -81,7 +81,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
|||||||
step={1}
|
step={1}
|
||||||
w={numberInputWidth}
|
w={numberInputWidth}
|
||||||
defaultValue={90}
|
defaultValue={90}
|
||||||
/>
|
/>
|
||||||
</InvControl>
|
</InvControl>
|
||||||
<InvControl label="Blue">
|
<InvControl label="Blue">
|
||||||
<InvNumberInput
|
<InvNumberInput
|
||||||
@ -92,7 +92,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
|||||||
step={1}
|
step={1}
|
||||||
w={numberInputWidth}
|
w={numberInputWidth}
|
||||||
defaultValue={255}
|
defaultValue={255}
|
||||||
/>
|
/>
|
||||||
</InvControl>
|
</InvControl>
|
||||||
<InvControl label="Alpha">
|
<InvControl label="Alpha">
|
||||||
<InvNumberInput
|
<InvNumberInput
|
||||||
|
@ -3,9 +3,7 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { InvControl } from 'common/components/InvControl/InvControl';
|
import { InvControl } from 'common/components/InvControl/InvControl';
|
||||||
import { InvSlider } from 'common/components/InvSlider/InvSlider';
|
import { InvSlider } from 'common/components/InvSlider/InvSlider';
|
||||||
import {
|
import { maxPromptsChanged } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||||
maxPromptsChanged,
|
|
||||||
} from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
@ -9,10 +9,7 @@ import { InvNumberInput } from 'common/components/InvNumberInput/InvNumberInput'
|
|||||||
import { InvSlider } from 'common/components/InvSlider/InvSlider';
|
import { InvSlider } from 'common/components/InvSlider/InvSlider';
|
||||||
import { InvText } from 'common/components/InvText/wrapper';
|
import { InvText } from 'common/components/InvText/wrapper';
|
||||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||||
import {
|
import { loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
|
||||||
loraRemoved,
|
|
||||||
loraWeightChanged,
|
|
||||||
} from 'features/lora/store/loraSlice';
|
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { FaTrashCan } from 'react-icons/fa6';
|
import { FaTrashCan } from 'react-icons/fa6';
|
||||||
|
|
||||||
|
@ -13,12 +13,10 @@ import { ALL_BASE_MODELS } from 'services/api/constants';
|
|||||||
import type {
|
import type {
|
||||||
LoRAModelConfigEntity,
|
LoRAModelConfigEntity,
|
||||||
MainModelConfigEntity,
|
MainModelConfigEntity,
|
||||||
OnnxModelConfigEntity,
|
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
import {
|
import {
|
||||||
useGetLoRAModelsQuery,
|
useGetLoRAModelsQuery,
|
||||||
useGetMainModelsQuery,
|
useGetMainModelsQuery,
|
||||||
useGetOnnxModelsQuery,
|
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import ModelListItem from './ModelListItem';
|
import ModelListItem from './ModelListItem';
|
||||||
@ -28,9 +26,9 @@ type ModelListProps = {
|
|||||||
setSelectedModelId: (name: string | undefined) => void;
|
setSelectedModelId: (name: string | undefined) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
|
type ModelFormat = 'all' | 'checkpoint' | 'diffusers';
|
||||||
|
|
||||||
type ModelType = 'main' | 'lora' | 'onnx';
|
type ModelType = 'main' | 'lora';
|
||||||
|
|
||||||
type CombinedModelFormat = ModelFormat | 'lora';
|
type CombinedModelFormat = ModelFormat | 'lora';
|
||||||
|
|
||||||
@ -77,26 +75,6 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
const { filteredOnnxModels, isLoadingOnnxModels } = useGetOnnxModelsQuery(
|
|
||||||
ALL_BASE_MODELS,
|
|
||||||
{
|
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
|
||||||
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
|
|
||||||
isLoadingOnnxModels: isLoading,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const { filteredOliveModels, isLoadingOliveModels } = useGetOnnxModelsQuery(
|
|
||||||
ALL_BASE_MODELS,
|
|
||||||
{
|
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
|
||||||
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
|
|
||||||
isLoadingOliveModels: isLoading,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||||
setNameFilter(e.target.value);
|
setNameFilter(e.target.value);
|
||||||
}, []);
|
}, []);
|
||||||
@ -126,20 +104,6 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
>
|
>
|
||||||
{t('modelManager.checkpointModels')}
|
{t('modelManager.checkpointModels')}
|
||||||
</InvButton>
|
</InvButton>
|
||||||
<InvButton
|
|
||||||
size="sm"
|
|
||||||
onClick={setModelFormatFilter.bind(null, 'onnx')}
|
|
||||||
isChecked={modelFormatFilter === 'onnx'}
|
|
||||||
>
|
|
||||||
{t('modelManager.onnxModels')}
|
|
||||||
</InvButton>
|
|
||||||
<InvButton
|
|
||||||
size="sm"
|
|
||||||
onClick={setModelFormatFilter.bind(null, 'olive')}
|
|
||||||
isChecked={modelFormatFilter === 'olive'}
|
|
||||||
>
|
|
||||||
{t('modelManager.oliveModels')}
|
|
||||||
</InvButton>
|
|
||||||
<InvButton
|
<InvButton
|
||||||
size="sm"
|
size="sm"
|
||||||
onClick={setModelFormatFilter.bind(null, 'lora')}
|
onClick={setModelFormatFilter.bind(null, 'lora')}
|
||||||
@ -202,34 +166,6 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
key="loras"
|
key="loras"
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{/* Olive List */}
|
|
||||||
{isLoadingOliveModels && (
|
|
||||||
<FetchingModelsLoader loadingMessage="Loading Olives..." />
|
|
||||||
)}
|
|
||||||
{['all', 'olive'].includes(modelFormatFilter) &&
|
|
||||||
!isLoadingOliveModels &&
|
|
||||||
filteredOliveModels.length > 0 && (
|
|
||||||
<ModelListWrapper
|
|
||||||
title="Olives"
|
|
||||||
modelList={filteredOliveModels}
|
|
||||||
selected={{ selectedModelId, setSelectedModelId }}
|
|
||||||
key="olive"
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{/* Onnx List */}
|
|
||||||
{isLoadingOnnxModels && (
|
|
||||||
<FetchingModelsLoader loadingMessage="Loading ONNX..." />
|
|
||||||
)}
|
|
||||||
{['all', 'onnx'].includes(modelFormatFilter) &&
|
|
||||||
!isLoadingOnnxModels &&
|
|
||||||
filteredOnnxModels.length > 0 && (
|
|
||||||
<ModelListWrapper
|
|
||||||
title="ONNX"
|
|
||||||
modelList={filteredOnnxModels}
|
|
||||||
selected={{ selectedModelId, setSelectedModelId }}
|
|
||||||
key="onnx"
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
@ -238,12 +174,7 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
|
|
||||||
export default memo(ModelList);
|
export default memo(ModelList);
|
||||||
|
|
||||||
const modelsFilter = <
|
const modelsFilter = <T extends MainModelConfigEntity | LoRAModelConfigEntity>(
|
||||||
T extends
|
|
||||||
| MainModelConfigEntity
|
|
||||||
| LoRAModelConfigEntity
|
|
||||||
| OnnxModelConfigEntity,
|
|
||||||
>(
|
|
||||||
data: EntityState<T, string> | undefined,
|
data: EntityState<T, string> | undefined,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_format: ModelFormat | undefined,
|
model_format: ModelFormat | undefined,
|
||||||
@ -282,10 +213,7 @@ StyledModelContainer.displayName = 'StyledModelContainer';
|
|||||||
|
|
||||||
type ModelListWrapperProps = {
|
type ModelListWrapperProps = {
|
||||||
title: string;
|
title: string;
|
||||||
modelList:
|
modelList: MainModelConfigEntity[] | LoRAModelConfigEntity[];
|
||||||
| MainModelConfigEntity[]
|
|
||||||
| LoRAModelConfigEntity[]
|
|
||||||
| OnnxModelConfigEntity[];
|
|
||||||
selected: ModelListProps;
|
selected: ModelListProps;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -14,7 +14,6 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import type {
|
import type {
|
||||||
LoRAModelConfigEntity,
|
LoRAModelConfigEntity,
|
||||||
MainModelConfigEntity,
|
MainModelConfigEntity,
|
||||||
OnnxModelConfigEntity,
|
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
import {
|
import {
|
||||||
useDeleteLoRAModelsMutation,
|
useDeleteLoRAModelsMutation,
|
||||||
@ -22,7 +21,7 @@ import {
|
|||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
|
|
||||||
type ModelListItemProps = {
|
type ModelListItemProps = {
|
||||||
model: MainModelConfigEntity | OnnxModelConfigEntity | LoRAModelConfigEntity;
|
model: MainModelConfigEntity | LoRAModelConfigEntity;
|
||||||
isSelected: boolean;
|
isSelected: boolean;
|
||||||
setSelectedModelId: (v: string | undefined) => void;
|
setSelectedModelId: (v: string | undefined) => void;
|
||||||
};
|
};
|
||||||
@ -44,7 +43,6 @@ const ModelListItem = (props: ModelListItemProps) => {
|
|||||||
const method = {
|
const method = {
|
||||||
main: deleteMainModel,
|
main: deleteMainModel,
|
||||||
lora: deleteLoRAModel,
|
lora: deleteLoRAModel,
|
||||||
onnx: deleteMainModel,
|
|
||||||
}[model.model_type];
|
}[model.model_type];
|
||||||
|
|
||||||
method(model)
|
method(model)
|
||||||
|
@ -45,6 +45,7 @@ const MainModelFieldInputComponent = (props: Props) => {
|
|||||||
modelEntities: data,
|
modelEntities: data,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
selectedModel: field.value,
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -45,6 +45,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
|
|||||||
modelEntities: data,
|
modelEntities: data,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
selectedModel: field.value,
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -45,6 +45,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
|
|||||||
modelEntities: data,
|
modelEntities: data,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
selectedModel: field.value,
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -137,10 +137,11 @@ const fieldValueReducer = <T extends FieldValue>(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const input = node.data?.inputs[fieldName];
|
const input = node.data?.inputs[fieldName];
|
||||||
if (!input || nodeIndex < 0 || !schema.safeParse(value).success) {
|
const result = schema.safeParse(value);
|
||||||
|
if (!input || nodeIndex < 0 || !result.success) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
input.value = value;
|
input.value = result.data;
|
||||||
};
|
};
|
||||||
|
|
||||||
const nodesSlice = createSlice({
|
const nodesSlice = createSlice({
|
||||||
|
@ -59,7 +59,6 @@ export const zBaseModel = z.enum([
|
|||||||
'sdxl-refiner',
|
'sdxl-refiner',
|
||||||
]);
|
]);
|
||||||
export const zModelType = z.enum([
|
export const zModelType = z.enum([
|
||||||
'onnx',
|
|
||||||
'main',
|
'main',
|
||||||
'vae',
|
'vae',
|
||||||
'lora',
|
'lora',
|
||||||
@ -80,23 +79,12 @@ export const zMainModelField = z.object({
|
|||||||
base_model: zBaseModel,
|
base_model: zBaseModel,
|
||||||
model_type: z.literal('main'),
|
model_type: z.literal('main'),
|
||||||
});
|
});
|
||||||
export const zONNXModelField = z.object({
|
|
||||||
model_name: zModelName,
|
|
||||||
base_model: zBaseModel,
|
|
||||||
model_type: z.literal('onnx'),
|
|
||||||
});
|
|
||||||
export const zMainOrONNXModelField = z.union([
|
|
||||||
zMainModelField,
|
|
||||||
zONNXModelField,
|
|
||||||
]);
|
|
||||||
export const zSDXLRefinerModelField = z.object({
|
export const zSDXLRefinerModelField = z.object({
|
||||||
model_name: z.string().min(1),
|
model_name: z.string().min(1),
|
||||||
base_model: z.literal('sdxl-refiner'),
|
base_model: z.literal('sdxl-refiner'),
|
||||||
model_type: z.literal('main'),
|
model_type: z.literal('main'),
|
||||||
});
|
});
|
||||||
export type MainModelField = z.infer<typeof zMainModelField>;
|
export type MainModelField = z.infer<typeof zMainModelField>;
|
||||||
export type ONNXModelField = z.infer<typeof zONNXModelField>;
|
|
||||||
export type MainOrONNXModelField = z.infer<typeof zMainOrONNXModelField>;
|
|
||||||
export type SDXLRefinerModelField = z.infer<typeof zSDXLRefinerModelField>;
|
export type SDXLRefinerModelField = z.infer<typeof zSDXLRefinerModelField>;
|
||||||
|
|
||||||
export const zSubModelType = z.enum([
|
export const zSubModelType = z.enum([
|
||||||
|
@ -39,7 +39,6 @@ export const MODEL_TYPES = [
|
|||||||
'ControlNetModelField',
|
'ControlNetModelField',
|
||||||
'LoRAModelField',
|
'LoRAModelField',
|
||||||
'MainModelField',
|
'MainModelField',
|
||||||
'ONNXModelField',
|
|
||||||
'SDXLMainModelField',
|
'SDXLMainModelField',
|
||||||
'SDXLRefinerModelField',
|
'SDXLRefinerModelField',
|
||||||
'VaeModelField',
|
'VaeModelField',
|
||||||
@ -70,7 +69,6 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
|||||||
LatentsField: 'pink.500',
|
LatentsField: 'pink.500',
|
||||||
LoRAModelField: 'teal.500',
|
LoRAModelField: 'teal.500',
|
||||||
MainModelField: 'teal.500',
|
MainModelField: 'teal.500',
|
||||||
ONNXModelField: 'teal.500',
|
|
||||||
SDXLMainModelField: 'teal.500',
|
SDXLMainModelField: 'teal.500',
|
||||||
SDXLRefinerModelField: 'teal.500',
|
SDXLRefinerModelField: 'teal.500',
|
||||||
StringField: 'yellow.500',
|
StringField: 'yellow.500',
|
||||||
|
@ -7,7 +7,7 @@ import {
|
|||||||
zImageField,
|
zImageField,
|
||||||
zIPAdapterModelField,
|
zIPAdapterModelField,
|
||||||
zLoRAModelField,
|
zLoRAModelField,
|
||||||
zMainOrONNXModelField,
|
zMainModelField,
|
||||||
zSchedulerField,
|
zSchedulerField,
|
||||||
zT2IAdapterModelField,
|
zT2IAdapterModelField,
|
||||||
zVAEModelField,
|
zVAEModelField,
|
||||||
@ -430,7 +430,7 @@ export const isColorFieldInputTemplate = (
|
|||||||
export const zMainModelFieldType = zFieldTypeBase.extend({
|
export const zMainModelFieldType = zFieldTypeBase.extend({
|
||||||
name: z.literal('MainModelField'),
|
name: z.literal('MainModelField'),
|
||||||
});
|
});
|
||||||
export const zMainModelFieldValue = zMainOrONNXModelField.optional();
|
export const zMainModelFieldValue = zMainModelField.optional();
|
||||||
export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||||
type: zMainModelFieldType,
|
type: zMainModelFieldType,
|
||||||
value: zMainModelFieldValue,
|
value: zMainModelFieldValue,
|
||||||
|
@ -5,7 +5,6 @@ import {
|
|||||||
zIPAdapterField,
|
zIPAdapterField,
|
||||||
zLoRAModelField,
|
zLoRAModelField,
|
||||||
zMainModelField,
|
zMainModelField,
|
||||||
zONNXModelField,
|
|
||||||
zSDXLRefinerModelField,
|
zSDXLRefinerModelField,
|
||||||
zT2IAdapterField,
|
zT2IAdapterField,
|
||||||
zVAEModelField,
|
zVAEModelField,
|
||||||
@ -23,10 +22,7 @@ const zControlNetMetadataItem = zControlField.deepPartial();
|
|||||||
const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
|
const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
|
||||||
const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial();
|
const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial();
|
||||||
const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial();
|
const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial();
|
||||||
const zModelMetadataItem = z.union([
|
const zModelMetadataItem = zMainModelField.deepPartial();
|
||||||
zMainModelField.deepPartial(),
|
|
||||||
zONNXModelField.deepPartial(),
|
|
||||||
]);
|
|
||||||
const zVAEModelMetadataItem = zVAEModelField.deepPartial();
|
const zVAEModelMetadataItem = zVAEModelField.deepPartial();
|
||||||
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
||||||
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
|
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
|
||||||
|
@ -273,11 +273,6 @@ const FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2: {
|
|||||||
isCollection: false,
|
isCollection: false,
|
||||||
isCollectionOrScalar: true,
|
isCollectionOrScalar: true,
|
||||||
},
|
},
|
||||||
ONNXModelField: {
|
|
||||||
name: 'ONNXModelField',
|
|
||||||
isCollection: false,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
},
|
|
||||||
T2IAdapterField: {
|
T2IAdapterField: {
|
||||||
name: 'T2IAdapterField',
|
name: 'T2IAdapterField',
|
||||||
isCollection: false,
|
isCollection: false,
|
||||||
|
@ -14,7 +14,6 @@ import {
|
|||||||
INPAINT_IMAGE,
|
INPAINT_IMAGE,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
ONNX_MODEL_LOADER,
|
|
||||||
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||||
SDXL_CANVAS_INPAINT_GRAPH,
|
SDXL_CANVAS_INPAINT_GRAPH,
|
||||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||||
@ -50,7 +49,6 @@ export const addVAEToGraph = (
|
|||||||
vae_model: vae,
|
vae_model: vae,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
const isOnnxModel = modelLoaderNodeId == ONNX_MODEL_LOADER;
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
graph.id === TEXT_TO_IMAGE_GRAPH ||
|
graph.id === TEXT_TO_IMAGE_GRAPH ||
|
||||||
@ -61,7 +59,7 @@ export const addVAEToGraph = (
|
|||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
field: 'vae',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: LATENTS_TO_IMAGE,
|
node_id: LATENTS_TO_IMAGE,
|
||||||
@ -79,7 +77,7 @@ export const addVAEToGraph = (
|
|||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
field: 'vae',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: isUsingScaledDimensions ? LATENTS_TO_IMAGE : CANVAS_OUTPUT,
|
node_id: isUsingScaledDimensions ? LATENTS_TO_IMAGE : CANVAS_OUTPUT,
|
||||||
@ -97,7 +95,7 @@ export const addVAEToGraph = (
|
|||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
field: 'vae',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: IMAGE_TO_LATENTS,
|
node_id: IMAGE_TO_LATENTS,
|
||||||
@ -116,7 +114,7 @@ export const addVAEToGraph = (
|
|||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
field: 'vae',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: INPAINT_IMAGE,
|
node_id: INPAINT_IMAGE,
|
||||||
@ -126,7 +124,7 @@ export const addVAEToGraph = (
|
|||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
field: 'vae',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: INPAINT_CREATE_MASK,
|
node_id: INPAINT_CREATE_MASK,
|
||||||
@ -136,7 +134,7 @@ export const addVAEToGraph = (
|
|||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
field: 'vae',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: LATENTS_TO_IMAGE,
|
node_id: LATENTS_TO_IMAGE,
|
||||||
@ -150,7 +148,7 @@ export const addVAEToGraph = (
|
|||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
field: 'vae',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
@ -168,7 +166,7 @@ export const addVAEToGraph = (
|
|||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
field: 'vae',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
|
@ -85,7 +85,6 @@ export const buildCanvasSDXLTextToImageGraph = (
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
// TODO: Actually create the graph correctly for ONNX
|
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
id: SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||||
nodes: {
|
nodes: {
|
||||||
|
@ -78,7 +78,6 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
// TODO: Actually create the graph correctly for ONNX
|
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: CANVAS_TEXT_TO_IMAGE_GRAPH,
|
id: CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||||
nodes: {
|
nodes: {
|
||||||
|
@ -70,7 +70,6 @@ export const buildLinearTextToImageGraph = (
|
|||||||
|
|
||||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
|
|
||||||
// TODO: Actually create the graph correctly for ONNX
|
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: TEXT_TO_IMAGE_GRAPH,
|
id: TEXT_TO_IMAGE_GRAPH,
|
||||||
nodes: {
|
nodes: {
|
||||||
|
@ -18,7 +18,6 @@ export const RANDOM_INT = 'rand_int';
|
|||||||
export const RANGE_OF_SIZE = 'range_of_size';
|
export const RANGE_OF_SIZE = 'range_of_size';
|
||||||
export const ITERATE = 'iterate';
|
export const ITERATE = 'iterate';
|
||||||
export const MAIN_MODEL_LOADER = 'main_model_loader';
|
export const MAIN_MODEL_LOADER = 'main_model_loader';
|
||||||
export const ONNX_MODEL_LOADER = 'onnx_model_loader';
|
|
||||||
export const VAE_LOADER = 'vae_loader';
|
export const VAE_LOADER = 'vae_loader';
|
||||||
export const LORA_LOADER = 'lora_loader';
|
export const LORA_LOADER = 'lora_loader';
|
||||||
export const CLIP_SKIP = 'clip_skip';
|
export const CLIP_SKIP = 'clip_skip';
|
||||||
|
@ -24,7 +24,14 @@ const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache'];
|
|||||||
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
|
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
|
||||||
const RESERVED_FIELD_TYPES = ['IsIntermediate'];
|
const RESERVED_FIELD_TYPES = ['IsIntermediate'];
|
||||||
|
|
||||||
const invocationDenylist: string[] = ['graph', 'linear_ui_output'];
|
const invocationDenylist: string[] = [
|
||||||
|
'graph',
|
||||||
|
'linear_ui_output',
|
||||||
|
'l2i_onnx',
|
||||||
|
'prompt_onnx',
|
||||||
|
't2l_onnx',
|
||||||
|
'onnx_model_loader',
|
||||||
|
];
|
||||||
|
|
||||||
const isReservedInputField = (nodeType: string, fieldName: string) => {
|
const isReservedInputField = (nodeType: string, fieldName: string) => {
|
||||||
if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
|
if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
|
||||||
|
@ -1,14 +1,10 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
import type {
|
import type { ImageDTO, MainModelField } from 'services/api/types';
|
||||||
ImageDTO,
|
|
||||||
MainModelField,
|
|
||||||
OnnxModelField,
|
|
||||||
} from 'services/api/types';
|
|
||||||
|
|
||||||
export const initialImageSelected = createAction<ImageDTO | undefined>(
|
export const initialImageSelected = createAction<ImageDTO | undefined>(
|
||||||
'generation/initialImageSelected'
|
'generation/initialImageSelected'
|
||||||
);
|
);
|
||||||
|
|
||||||
export const modelSelected = createAction<MainModelField | OnnxModelField>(
|
export const modelSelected = createAction<MainModelField>(
|
||||||
'generation/modelSelected'
|
'generation/modelSelected'
|
||||||
);
|
);
|
||||||
|
@ -15,7 +15,6 @@ import type {
|
|||||||
MainModelConfig,
|
MainModelConfig,
|
||||||
MergeModelConfig,
|
MergeModelConfig,
|
||||||
ModelType,
|
ModelType,
|
||||||
OnnxModelConfig,
|
|
||||||
T2IAdapterModelConfig,
|
T2IAdapterModelConfig,
|
||||||
TextualInversionModelConfig,
|
TextualInversionModelConfig,
|
||||||
VaeModelConfig,
|
VaeModelConfig,
|
||||||
@ -32,8 +31,6 @@ export type MainModelConfigEntity =
|
|||||||
| DiffusersModelConfigEntity
|
| DiffusersModelConfigEntity
|
||||||
| CheckpointModelConfigEntity;
|
| CheckpointModelConfigEntity;
|
||||||
|
|
||||||
export type OnnxModelConfigEntity = OnnxModelConfig & { id: string };
|
|
||||||
|
|
||||||
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
|
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
|
||||||
|
|
||||||
export type ControlNetModelConfigEntity = ControlNetModelConfig & {
|
export type ControlNetModelConfigEntity = ControlNetModelConfig & {
|
||||||
@ -56,7 +53,6 @@ export type VaeModelConfigEntity = VaeModelConfig & { id: string };
|
|||||||
|
|
||||||
export type AnyModelConfigEntity =
|
export type AnyModelConfigEntity =
|
||||||
| MainModelConfigEntity
|
| MainModelConfigEntity
|
||||||
| OnnxModelConfigEntity
|
|
||||||
| LoRAModelConfigEntity
|
| LoRAModelConfigEntity
|
||||||
| ControlNetModelConfigEntity
|
| ControlNetModelConfigEntity
|
||||||
| IPAdapterModelConfigEntity
|
| IPAdapterModelConfigEntity
|
||||||
@ -138,9 +134,6 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
|||||||
export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
});
|
});
|
||||||
const onnxModelsAdapter = createEntityAdapter<OnnxModelConfigEntity>({
|
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
|
||||||
});
|
|
||||||
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
});
|
});
|
||||||
@ -187,46 +180,6 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
|
|||||||
|
|
||||||
export const modelsApi = api.injectEndpoints({
|
export const modelsApi = api.injectEndpoints({
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
getOnnxModels: build.query<
|
|
||||||
EntityState<OnnxModelConfigEntity, string>,
|
|
||||||
BaseModelType[]
|
|
||||||
>({
|
|
||||||
query: (base_models) => {
|
|
||||||
const params = {
|
|
||||||
model_type: 'onnx',
|
|
||||||
base_models,
|
|
||||||
};
|
|
||||||
|
|
||||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
|
||||||
return `models/?${query}`;
|
|
||||||
},
|
|
||||||
providesTags: (result) => {
|
|
||||||
const tags: ApiTagDescription[] = [
|
|
||||||
{ type: 'OnnxModel', id: LIST_TAG },
|
|
||||||
'Model',
|
|
||||||
];
|
|
||||||
|
|
||||||
if (result) {
|
|
||||||
tags.push(
|
|
||||||
...result.ids.map((id) => ({
|
|
||||||
type: 'OnnxModel' as const,
|
|
||||||
id,
|
|
||||||
}))
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return tags;
|
|
||||||
},
|
|
||||||
transformResponse: (response: { models: OnnxModelConfig[] }) => {
|
|
||||||
const entities = createModelEntities<OnnxModelConfigEntity>(
|
|
||||||
response.models
|
|
||||||
);
|
|
||||||
return onnxModelsAdapter.setAll(
|
|
||||||
onnxModelsAdapter.getInitialState(),
|
|
||||||
entities
|
|
||||||
);
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
getMainModels: build.query<
|
getMainModels: build.query<
|
||||||
EntityState<MainModelConfigEntity, string>,
|
EntityState<MainModelConfigEntity, string>,
|
||||||
BaseModelType[]
|
BaseModelType[]
|
||||||
@ -583,7 +536,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
|
|
||||||
export const {
|
export const {
|
||||||
useGetMainModelsQuery,
|
useGetMainModelsQuery,
|
||||||
useGetOnnxModelsQuery,
|
|
||||||
useGetControlNetModelsQuery,
|
useGetControlNetModelsQuery,
|
||||||
useGetIPAdapterModelsQuery,
|
useGetIPAdapterModelsQuery,
|
||||||
useGetT2IAdapterModelsQuery,
|
useGetT2IAdapterModelsQuery,
|
||||||
|
@ -32,7 +32,6 @@ export const tagTypes = [
|
|||||||
'Model',
|
'Model',
|
||||||
'T2IAdapterModel',
|
'T2IAdapterModel',
|
||||||
'MainModel',
|
'MainModel',
|
||||||
'OnnxModel',
|
|
||||||
'VaeModel',
|
'VaeModel',
|
||||||
'IPAdapterModel',
|
'IPAdapterModel',
|
||||||
'TextualInversionModel',
|
'TextualInversionModel',
|
||||||
|
File diff suppressed because one or more lines are too long
@ -56,7 +56,6 @@ export type SubModelType = s['SubModelType'];
|
|||||||
export type BaseModelType =
|
export type BaseModelType =
|
||||||
s['invokeai__backend__model_management__models__base__BaseModelType'];
|
s['invokeai__backend__model_management__models__base__BaseModelType'];
|
||||||
export type MainModelField = s['MainModelField'];
|
export type MainModelField = s['MainModelField'];
|
||||||
export type OnnxModelField = s['OnnxModelField'];
|
|
||||||
export type VAEModelField = s['VAEModelField'];
|
export type VAEModelField = s['VAEModelField'];
|
||||||
export type LoRAModelField = s['LoRAModelField'];
|
export type LoRAModelField = s['LoRAModelField'];
|
||||||
export type LoRAModelFormat = s['LoRAModelFormat'];
|
export type LoRAModelFormat = s['LoRAModelFormat'];
|
||||||
@ -91,7 +90,6 @@ export type CheckpointModelConfig =
|
|||||||
| s['StableDiffusion1ModelCheckpointConfig']
|
| s['StableDiffusion1ModelCheckpointConfig']
|
||||||
| s['StableDiffusion2ModelCheckpointConfig']
|
| s['StableDiffusion2ModelCheckpointConfig']
|
||||||
| s['StableDiffusionXLModelCheckpointConfig'];
|
| s['StableDiffusionXLModelCheckpointConfig'];
|
||||||
export type OnnxModelConfig = s['ONNXStableDiffusion1ModelConfig'];
|
|
||||||
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
||||||
export type AnyModelConfig =
|
export type AnyModelConfig =
|
||||||
| LoRAModelConfig
|
| LoRAModelConfig
|
||||||
@ -100,8 +98,7 @@ export type AnyModelConfig =
|
|||||||
| IPAdapterModelConfig
|
| IPAdapterModelConfig
|
||||||
| T2IAdapterModelConfig
|
| T2IAdapterModelConfig
|
||||||
| TextualInversionModelConfig
|
| TextualInversionModelConfig
|
||||||
| MainModelConfig
|
| MainModelConfig;
|
||||||
| OnnxModelConfig;
|
|
||||||
|
|
||||||
export type MergeModelConfig = s['Body_merge_models'];
|
export type MergeModelConfig = s['Body_merge_models'];
|
||||||
export type ImportModelConfig = s['Body_import_model'];
|
export type ImportModelConfig = s['Body_import_model'];
|
||||||
@ -137,13 +134,11 @@ export type CompelInvocation = s['CompelInvocation'];
|
|||||||
export type DynamicPromptInvocation = s['DynamicPromptInvocation'];
|
export type DynamicPromptInvocation = s['DynamicPromptInvocation'];
|
||||||
export type NoiseInvocation = s['NoiseInvocation'];
|
export type NoiseInvocation = s['NoiseInvocation'];
|
||||||
export type DenoiseLatentsInvocation = s['DenoiseLatentsInvocation'];
|
export type DenoiseLatentsInvocation = s['DenoiseLatentsInvocation'];
|
||||||
export type ONNXTextToLatentsInvocation = s['ONNXTextToLatentsInvocation'];
|
|
||||||
export type SDXLLoraLoaderInvocation = s['SDXLLoraLoaderInvocation'];
|
export type SDXLLoraLoaderInvocation = s['SDXLLoraLoaderInvocation'];
|
||||||
export type ImageToLatentsInvocation = s['ImageToLatentsInvocation'];
|
export type ImageToLatentsInvocation = s['ImageToLatentsInvocation'];
|
||||||
export type LatentsToImageInvocation = s['LatentsToImageInvocation'];
|
export type LatentsToImageInvocation = s['LatentsToImageInvocation'];
|
||||||
export type ImageCollectionInvocation = s['ImageCollectionInvocation'];
|
export type ImageCollectionInvocation = s['ImageCollectionInvocation'];
|
||||||
export type MainModelLoaderInvocation = s['MainModelLoaderInvocation'];
|
export type MainModelLoaderInvocation = s['MainModelLoaderInvocation'];
|
||||||
export type OnnxModelLoaderInvocation = s['OnnxModelLoaderInvocation'];
|
|
||||||
export type LoraLoaderInvocation = s['LoraLoaderInvocation'];
|
export type LoraLoaderInvocation = s['LoraLoaderInvocation'];
|
||||||
export type ESRGANInvocation = s['ESRGANInvocation'];
|
export type ESRGANInvocation = s['ESRGANInvocation'];
|
||||||
export type DivideInvocation = s['DivideInvocation'];
|
export type DivideInvocation = s['DivideInvocation'];
|
||||||
|
Loading…
Reference in New Issue
Block a user