fix(ui): fix workflow editor model selector, excise ONNX

Ensure workflow editor model selector component gets a value

This introduced some funky type issues related to ONNX models. ONNX doesn't work anyways (unmaintained). Instead of fixing the types to work with a non-working feature, ONNX is now removed entirely from the UI.

- Remove all refs to ONNX (and Olives)
- Fix some type issues
- Add ONNX nodes to the nodes denylist (so they are not visible in UI)
- Update VAE graph helper, which still had some ONNX logic. It's a very simple change and doesn't change any logic. Just removes some conditions that were for ONNX. I tested it and nothing broke.
- Regenerate types
- Fix prettier and eslint ignores for generated types
- Lint
This commit is contained in:
psychedelicious 2024-01-03 09:01:15 +11:00
parent ebe717099e
commit b1b5c0d3b2
27 changed files with 1524 additions and 2078 deletions

View File

@ -7,4 +7,4 @@ stats.html
index.html index.html
.yarn/ .yarn/
*.scss *.scss
src/services/api/schema.d.ts src/services/api/schema.ts

View File

@ -9,7 +9,7 @@ index.html
.yarn/ .yarn/
.yalc/ .yalc/
*.scss *.scss
src/services/api/schema.d.ts src/services/api/schema.ts
static/ static/
src/theme/css/overlayscrollbars.css src/theme/css/overlayscrollbars.css
src/theme_/css/overlayscrollbars.css src/theme_/css/overlayscrollbars.css

View File

@ -70,7 +70,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
step={1} step={1}
w={numberInputWidth} w={numberInputWidth}
defaultValue={90} defaultValue={90}
/> />
</InvControl> </InvControl>
<InvControl label="Green"> <InvControl label="Green">
<InvNumberInput <InvNumberInput
@ -81,7 +81,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
step={1} step={1}
w={numberInputWidth} w={numberInputWidth}
defaultValue={90} defaultValue={90}
/> />
</InvControl> </InvControl>
<InvControl label="Blue"> <InvControl label="Blue">
<InvNumberInput <InvNumberInput
@ -92,7 +92,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
step={1} step={1}
w={numberInputWidth} w={numberInputWidth}
defaultValue={255} defaultValue={255}
/> />
</InvControl> </InvControl>
<InvControl label="Alpha"> <InvControl label="Alpha">
<InvNumberInput <InvNumberInput

View File

@ -3,9 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InvControl } from 'common/components/InvControl/InvControl'; import { InvControl } from 'common/components/InvControl/InvControl';
import { InvSlider } from 'common/components/InvSlider/InvSlider'; import { InvSlider } from 'common/components/InvSlider/InvSlider';
import { import { maxPromptsChanged } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
maxPromptsChanged,
} from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';

View File

@ -9,10 +9,7 @@ import { InvNumberInput } from 'common/components/InvNumberInput/InvNumberInput'
import { InvSlider } from 'common/components/InvSlider/InvSlider'; import { InvSlider } from 'common/components/InvSlider/InvSlider';
import { InvText } from 'common/components/InvText/wrapper'; import { InvText } from 'common/components/InvText/wrapper';
import type { LoRA } from 'features/lora/store/loraSlice'; import type { LoRA } from 'features/lora/store/loraSlice';
import { import { loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
loraRemoved,
loraWeightChanged,
} from 'features/lora/store/loraSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { FaTrashCan } from 'react-icons/fa6'; import { FaTrashCan } from 'react-icons/fa6';

View File

@ -13,12 +13,10 @@ import { ALL_BASE_MODELS } from 'services/api/constants';
import type { import type {
LoRAModelConfigEntity, LoRAModelConfigEntity,
MainModelConfigEntity, MainModelConfigEntity,
OnnxModelConfigEntity,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import { import {
useGetLoRAModelsQuery, useGetLoRAModelsQuery,
useGetMainModelsQuery, useGetMainModelsQuery,
useGetOnnxModelsQuery,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem'; import ModelListItem from './ModelListItem';
@ -28,9 +26,9 @@ type ModelListProps = {
setSelectedModelId: (name: string | undefined) => void; setSelectedModelId: (name: string | undefined) => void;
}; };
type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx'; type ModelFormat = 'all' | 'checkpoint' | 'diffusers';
type ModelType = 'main' | 'lora' | 'onnx'; type ModelType = 'main' | 'lora';
type CombinedModelFormat = ModelFormat | 'lora'; type CombinedModelFormat = ModelFormat | 'lora';
@ -77,26 +75,6 @@ const ModelList = (props: ModelListProps) => {
} }
); );
const { filteredOnnxModels, isLoadingOnnxModels } = useGetOnnxModelsQuery(
ALL_BASE_MODELS,
{
selectFromResult: ({ data, isLoading }) => ({
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
isLoadingOnnxModels: isLoading,
}),
}
);
const { filteredOliveModels, isLoadingOliveModels } = useGetOnnxModelsQuery(
ALL_BASE_MODELS,
{
selectFromResult: ({ data, isLoading }) => ({
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
isLoadingOliveModels: isLoading,
}),
}
);
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => { const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value); setNameFilter(e.target.value);
}, []); }, []);
@ -126,20 +104,6 @@ const ModelList = (props: ModelListProps) => {
> >
{t('modelManager.checkpointModels')} {t('modelManager.checkpointModels')}
</InvButton> </InvButton>
<InvButton
size="sm"
onClick={setModelFormatFilter.bind(null, 'onnx')}
isChecked={modelFormatFilter === 'onnx'}
>
{t('modelManager.onnxModels')}
</InvButton>
<InvButton
size="sm"
onClick={setModelFormatFilter.bind(null, 'olive')}
isChecked={modelFormatFilter === 'olive'}
>
{t('modelManager.oliveModels')}
</InvButton>
<InvButton <InvButton
size="sm" size="sm"
onClick={setModelFormatFilter.bind(null, 'lora')} onClick={setModelFormatFilter.bind(null, 'lora')}
@ -202,34 +166,6 @@ const ModelList = (props: ModelListProps) => {
key="loras" key="loras"
/> />
)} )}
{/* Olive List */}
{isLoadingOliveModels && (
<FetchingModelsLoader loadingMessage="Loading Olives..." />
)}
{['all', 'olive'].includes(modelFormatFilter) &&
!isLoadingOliveModels &&
filteredOliveModels.length > 0 && (
<ModelListWrapper
title="Olives"
modelList={filteredOliveModels}
selected={{ selectedModelId, setSelectedModelId }}
key="olive"
/>
)}
{/* Onnx List */}
{isLoadingOnnxModels && (
<FetchingModelsLoader loadingMessage="Loading ONNX..." />
)}
{['all', 'onnx'].includes(modelFormatFilter) &&
!isLoadingOnnxModels &&
filteredOnnxModels.length > 0 && (
<ModelListWrapper
title="ONNX"
modelList={filteredOnnxModels}
selected={{ selectedModelId, setSelectedModelId }}
key="onnx"
/>
)}
</Flex> </Flex>
</Flex> </Flex>
</Flex> </Flex>
@ -238,12 +174,7 @@ const ModelList = (props: ModelListProps) => {
export default memo(ModelList); export default memo(ModelList);
const modelsFilter = < const modelsFilter = <T extends MainModelConfigEntity | LoRAModelConfigEntity>(
T extends
| MainModelConfigEntity
| LoRAModelConfigEntity
| OnnxModelConfigEntity,
>(
data: EntityState<T, string> | undefined, data: EntityState<T, string> | undefined,
model_type: ModelType, model_type: ModelType,
model_format: ModelFormat | undefined, model_format: ModelFormat | undefined,
@ -282,10 +213,7 @@ StyledModelContainer.displayName = 'StyledModelContainer';
type ModelListWrapperProps = { type ModelListWrapperProps = {
title: string; title: string;
modelList: modelList: MainModelConfigEntity[] | LoRAModelConfigEntity[];
| MainModelConfigEntity[]
| LoRAModelConfigEntity[]
| OnnxModelConfigEntity[];
selected: ModelListProps; selected: ModelListProps;
}; };

View File

@ -14,7 +14,6 @@ import { useTranslation } from 'react-i18next';
import type { import type {
LoRAModelConfigEntity, LoRAModelConfigEntity,
MainModelConfigEntity, MainModelConfigEntity,
OnnxModelConfigEntity,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import { import {
useDeleteLoRAModelsMutation, useDeleteLoRAModelsMutation,
@ -22,7 +21,7 @@ import {
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
type ModelListItemProps = { type ModelListItemProps = {
model: MainModelConfigEntity | OnnxModelConfigEntity | LoRAModelConfigEntity; model: MainModelConfigEntity | LoRAModelConfigEntity;
isSelected: boolean; isSelected: boolean;
setSelectedModelId: (v: string | undefined) => void; setSelectedModelId: (v: string | undefined) => void;
}; };
@ -44,7 +43,6 @@ const ModelListItem = (props: ModelListItemProps) => {
const method = { const method = {
main: deleteMainModel, main: deleteMainModel,
lora: deleteLoRAModel, lora: deleteLoRAModel,
onnx: deleteMainModel,
}[model.model_type]; }[model.model_type];
method(model) method(model)

View File

@ -45,6 +45,7 @@ const MainModelFieldInputComponent = (props: Props) => {
modelEntities: data, modelEntities: data,
onChange: _onChange, onChange: _onChange,
isLoading, isLoading,
selectedModel: field.value,
}); });
return ( return (

View File

@ -45,6 +45,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
modelEntities: data, modelEntities: data,
onChange: _onChange, onChange: _onChange,
isLoading, isLoading,
selectedModel: field.value,
}); });
return ( return (

View File

@ -45,6 +45,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
modelEntities: data, modelEntities: data,
onChange: _onChange, onChange: _onChange,
isLoading, isLoading,
selectedModel: field.value,
}); });
return ( return (

View File

@ -137,10 +137,11 @@ const fieldValueReducer = <T extends FieldValue>(
return; return;
} }
const input = node.data?.inputs[fieldName]; const input = node.data?.inputs[fieldName];
if (!input || nodeIndex < 0 || !schema.safeParse(value).success) { const result = schema.safeParse(value);
if (!input || nodeIndex < 0 || !result.success) {
return; return;
} }
input.value = value; input.value = result.data;
}; };
const nodesSlice = createSlice({ const nodesSlice = createSlice({

View File

@ -59,7 +59,6 @@ export const zBaseModel = z.enum([
'sdxl-refiner', 'sdxl-refiner',
]); ]);
export const zModelType = z.enum([ export const zModelType = z.enum([
'onnx',
'main', 'main',
'vae', 'vae',
'lora', 'lora',
@ -80,23 +79,12 @@ export const zMainModelField = z.object({
base_model: zBaseModel, base_model: zBaseModel,
model_type: z.literal('main'), model_type: z.literal('main'),
}); });
export const zONNXModelField = z.object({
model_name: zModelName,
base_model: zBaseModel,
model_type: z.literal('onnx'),
});
export const zMainOrONNXModelField = z.union([
zMainModelField,
zONNXModelField,
]);
export const zSDXLRefinerModelField = z.object({ export const zSDXLRefinerModelField = z.object({
model_name: z.string().min(1), model_name: z.string().min(1),
base_model: z.literal('sdxl-refiner'), base_model: z.literal('sdxl-refiner'),
model_type: z.literal('main'), model_type: z.literal('main'),
}); });
export type MainModelField = z.infer<typeof zMainModelField>; export type MainModelField = z.infer<typeof zMainModelField>;
export type ONNXModelField = z.infer<typeof zONNXModelField>;
export type MainOrONNXModelField = z.infer<typeof zMainOrONNXModelField>;
export type SDXLRefinerModelField = z.infer<typeof zSDXLRefinerModelField>; export type SDXLRefinerModelField = z.infer<typeof zSDXLRefinerModelField>;
export const zSubModelType = z.enum([ export const zSubModelType = z.enum([

View File

@ -39,7 +39,6 @@ export const MODEL_TYPES = [
'ControlNetModelField', 'ControlNetModelField',
'LoRAModelField', 'LoRAModelField',
'MainModelField', 'MainModelField',
'ONNXModelField',
'SDXLMainModelField', 'SDXLMainModelField',
'SDXLRefinerModelField', 'SDXLRefinerModelField',
'VaeModelField', 'VaeModelField',
@ -70,7 +69,6 @@ export const FIELD_COLORS: { [key: string]: string } = {
LatentsField: 'pink.500', LatentsField: 'pink.500',
LoRAModelField: 'teal.500', LoRAModelField: 'teal.500',
MainModelField: 'teal.500', MainModelField: 'teal.500',
ONNXModelField: 'teal.500',
SDXLMainModelField: 'teal.500', SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500', SDXLRefinerModelField: 'teal.500',
StringField: 'yellow.500', StringField: 'yellow.500',

View File

@ -7,7 +7,7 @@ import {
zImageField, zImageField,
zIPAdapterModelField, zIPAdapterModelField,
zLoRAModelField, zLoRAModelField,
zMainOrONNXModelField, zMainModelField,
zSchedulerField, zSchedulerField,
zT2IAdapterModelField, zT2IAdapterModelField,
zVAEModelField, zVAEModelField,
@ -430,7 +430,7 @@ export const isColorFieldInputTemplate = (
export const zMainModelFieldType = zFieldTypeBase.extend({ export const zMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('MainModelField'), name: z.literal('MainModelField'),
}); });
export const zMainModelFieldValue = zMainOrONNXModelField.optional(); export const zMainModelFieldValue = zMainModelField.optional();
export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
type: zMainModelFieldType, type: zMainModelFieldType,
value: zMainModelFieldValue, value: zMainModelFieldValue,

View File

@ -5,7 +5,6 @@ import {
zIPAdapterField, zIPAdapterField,
zLoRAModelField, zLoRAModelField,
zMainModelField, zMainModelField,
zONNXModelField,
zSDXLRefinerModelField, zSDXLRefinerModelField,
zT2IAdapterField, zT2IAdapterField,
zVAEModelField, zVAEModelField,
@ -23,10 +22,7 @@ const zControlNetMetadataItem = zControlField.deepPartial();
const zIPAdapterMetadataItem = zIPAdapterField.deepPartial(); const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial(); const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial();
const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial(); const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial();
const zModelMetadataItem = z.union([ const zModelMetadataItem = zMainModelField.deepPartial();
zMainModelField.deepPartial(),
zONNXModelField.deepPartial(),
]);
const zVAEModelMetadataItem = zVAEModelField.deepPartial(); const zVAEModelMetadataItem = zVAEModelField.deepPartial();
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>; export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>; export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;

View File

@ -273,11 +273,6 @@ const FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2: {
isCollection: false, isCollection: false,
isCollectionOrScalar: true, isCollectionOrScalar: true,
}, },
ONNXModelField: {
name: 'ONNXModelField',
isCollection: false,
isCollectionOrScalar: false,
},
T2IAdapterField: { T2IAdapterField: {
name: 'T2IAdapterField', name: 'T2IAdapterField',
isCollection: false, isCollection: false,

View File

@ -14,7 +14,6 @@ import {
INPAINT_IMAGE, INPAINT_IMAGE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
ONNX_MODEL_LOADER,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH, SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_CANVAS_INPAINT_GRAPH, SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH,
@ -50,7 +49,6 @@ export const addVAEToGraph = (
vae_model: vae, vae_model: vae,
}; };
} }
const isOnnxModel = modelLoaderNodeId == ONNX_MODEL_LOADER;
if ( if (
graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === TEXT_TO_IMAGE_GRAPH ||
@ -61,7 +59,7 @@ export const addVAEToGraph = (
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER, node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae', field: 'vae',
}, },
destination: { destination: {
node_id: LATENTS_TO_IMAGE, node_id: LATENTS_TO_IMAGE,
@ -79,7 +77,7 @@ export const addVAEToGraph = (
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER, node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae', field: 'vae',
}, },
destination: { destination: {
node_id: isUsingScaledDimensions ? LATENTS_TO_IMAGE : CANVAS_OUTPUT, node_id: isUsingScaledDimensions ? LATENTS_TO_IMAGE : CANVAS_OUTPUT,
@ -97,7 +95,7 @@ export const addVAEToGraph = (
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER, node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae', field: 'vae',
}, },
destination: { destination: {
node_id: IMAGE_TO_LATENTS, node_id: IMAGE_TO_LATENTS,
@ -116,7 +114,7 @@ export const addVAEToGraph = (
{ {
source: { source: {
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER, node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae', field: 'vae',
}, },
destination: { destination: {
node_id: INPAINT_IMAGE, node_id: INPAINT_IMAGE,
@ -126,7 +124,7 @@ export const addVAEToGraph = (
{ {
source: { source: {
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER, node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae', field: 'vae',
}, },
destination: { destination: {
node_id: INPAINT_CREATE_MASK, node_id: INPAINT_CREATE_MASK,
@ -136,7 +134,7 @@ export const addVAEToGraph = (
{ {
source: { source: {
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER, node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae', field: 'vae',
}, },
destination: { destination: {
node_id: LATENTS_TO_IMAGE, node_id: LATENTS_TO_IMAGE,
@ -150,7 +148,7 @@ export const addVAEToGraph = (
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER, node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae', field: 'vae',
}, },
destination: { destination: {
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK, node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
@ -168,7 +166,7 @@ export const addVAEToGraph = (
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER, node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae', field: 'vae',
}, },
destination: { destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK, node_id: SDXL_REFINER_INPAINT_CREATE_MASK,

View File

@ -85,7 +85,6 @@ export const buildCanvasSDXLTextToImageGraph = (
*/ */
// copy-pasted graph from node editor, filled in with state values & friendly node ids // copy-pasted graph from node editor, filled in with state values & friendly node ids
// TODO: Actually create the graph correctly for ONNX
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH, id: SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
nodes: { nodes: {

View File

@ -78,7 +78,6 @@ export const buildCanvasTextToImageGraph = (
*/ */
// copy-pasted graph from node editor, filled in with state values & friendly node ids // copy-pasted graph from node editor, filled in with state values & friendly node ids
// TODO: Actually create the graph correctly for ONNX
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: CANVAS_TEXT_TO_IMAGE_GRAPH, id: CANVAS_TEXT_TO_IMAGE_GRAPH,
nodes: { nodes: {

View File

@ -70,7 +70,6 @@ export const buildLinearTextToImageGraph = (
// copy-pasted graph from node editor, filled in with state values & friendly node ids // copy-pasted graph from node editor, filled in with state values & friendly node ids
// TODO: Actually create the graph correctly for ONNX
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: TEXT_TO_IMAGE_GRAPH, id: TEXT_TO_IMAGE_GRAPH,
nodes: { nodes: {

View File

@ -18,7 +18,6 @@ export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size'; export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate'; export const ITERATE = 'iterate';
export const MAIN_MODEL_LOADER = 'main_model_loader'; export const MAIN_MODEL_LOADER = 'main_model_loader';
export const ONNX_MODEL_LOADER = 'onnx_model_loader';
export const VAE_LOADER = 'vae_loader'; export const VAE_LOADER = 'vae_loader';
export const LORA_LOADER = 'lora_loader'; export const LORA_LOADER = 'lora_loader';
export const CLIP_SKIP = 'clip_skip'; export const CLIP_SKIP = 'clip_skip';

View File

@ -24,7 +24,14 @@ const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache'];
const RESERVED_OUTPUT_FIELD_NAMES = ['type']; const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
const RESERVED_FIELD_TYPES = ['IsIntermediate']; const RESERVED_FIELD_TYPES = ['IsIntermediate'];
const invocationDenylist: string[] = ['graph', 'linear_ui_output']; const invocationDenylist: string[] = [
'graph',
'linear_ui_output',
'l2i_onnx',
'prompt_onnx',
't2l_onnx',
'onnx_model_loader',
];
const isReservedInputField = (nodeType: string, fieldName: string) => { const isReservedInputField = (nodeType: string, fieldName: string) => {
if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) { if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {

View File

@ -1,14 +1,10 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import type { import type { ImageDTO, MainModelField } from 'services/api/types';
ImageDTO,
MainModelField,
OnnxModelField,
} from 'services/api/types';
export const initialImageSelected = createAction<ImageDTO | undefined>( export const initialImageSelected = createAction<ImageDTO | undefined>(
'generation/initialImageSelected' 'generation/initialImageSelected'
); );
export const modelSelected = createAction<MainModelField | OnnxModelField>( export const modelSelected = createAction<MainModelField>(
'generation/modelSelected' 'generation/modelSelected'
); );

View File

@ -15,7 +15,6 @@ import type {
MainModelConfig, MainModelConfig,
MergeModelConfig, MergeModelConfig,
ModelType, ModelType,
OnnxModelConfig,
T2IAdapterModelConfig, T2IAdapterModelConfig,
TextualInversionModelConfig, TextualInversionModelConfig,
VaeModelConfig, VaeModelConfig,
@ -32,8 +31,6 @@ export type MainModelConfigEntity =
| DiffusersModelConfigEntity | DiffusersModelConfigEntity
| CheckpointModelConfigEntity; | CheckpointModelConfigEntity;
export type OnnxModelConfigEntity = OnnxModelConfig & { id: string };
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string }; export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
export type ControlNetModelConfigEntity = ControlNetModelConfig & { export type ControlNetModelConfigEntity = ControlNetModelConfig & {
@ -56,7 +53,6 @@ export type VaeModelConfigEntity = VaeModelConfig & { id: string };
export type AnyModelConfigEntity = export type AnyModelConfigEntity =
| MainModelConfigEntity | MainModelConfigEntity
| OnnxModelConfigEntity
| LoRAModelConfigEntity | LoRAModelConfigEntity
| ControlNetModelConfigEntity | ControlNetModelConfigEntity
| IPAdapterModelConfigEntity | IPAdapterModelConfigEntity
@ -138,9 +134,6 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({ export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
const onnxModelsAdapter = createEntityAdapter<OnnxModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({ export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
@ -187,46 +180,6 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
export const modelsApi = api.injectEndpoints({ export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
getOnnxModels: build.query<
EntityState<OnnxModelConfigEntity, string>,
BaseModelType[]
>({
query: (base_models) => {
const params = {
model_type: 'onnx',
base_models,
};
const query = queryString.stringify(params, { arrayFormat: 'none' });
return `models/?${query}`;
},
providesTags: (result) => {
const tags: ApiTagDescription[] = [
{ type: 'OnnxModel', id: LIST_TAG },
'Model',
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'OnnxModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: OnnxModelConfig[] }) => {
const entities = createModelEntities<OnnxModelConfigEntity>(
response.models
);
return onnxModelsAdapter.setAll(
onnxModelsAdapter.getInitialState(),
entities
);
},
}),
getMainModels: build.query< getMainModels: build.query<
EntityState<MainModelConfigEntity, string>, EntityState<MainModelConfigEntity, string>,
BaseModelType[] BaseModelType[]
@ -583,7 +536,6 @@ export const modelsApi = api.injectEndpoints({
export const { export const {
useGetMainModelsQuery, useGetMainModelsQuery,
useGetOnnxModelsQuery,
useGetControlNetModelsQuery, useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery, useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery, useGetT2IAdapterModelsQuery,

View File

@ -32,7 +32,6 @@ export const tagTypes = [
'Model', 'Model',
'T2IAdapterModel', 'T2IAdapterModel',
'MainModel', 'MainModel',
'OnnxModel',
'VaeModel', 'VaeModel',
'IPAdapterModel', 'IPAdapterModel',
'TextualInversionModel', 'TextualInversionModel',

File diff suppressed because one or more lines are too long

View File

@ -56,7 +56,6 @@ export type SubModelType = s['SubModelType'];
export type BaseModelType = export type BaseModelType =
s['invokeai__backend__model_management__models__base__BaseModelType']; s['invokeai__backend__model_management__models__base__BaseModelType'];
export type MainModelField = s['MainModelField']; export type MainModelField = s['MainModelField'];
export type OnnxModelField = s['OnnxModelField'];
export type VAEModelField = s['VAEModelField']; export type VAEModelField = s['VAEModelField'];
export type LoRAModelField = s['LoRAModelField']; export type LoRAModelField = s['LoRAModelField'];
export type LoRAModelFormat = s['LoRAModelFormat']; export type LoRAModelFormat = s['LoRAModelFormat'];
@ -91,7 +90,6 @@ export type CheckpointModelConfig =
| s['StableDiffusion1ModelCheckpointConfig'] | s['StableDiffusion1ModelCheckpointConfig']
| s['StableDiffusion2ModelCheckpointConfig'] | s['StableDiffusion2ModelCheckpointConfig']
| s['StableDiffusionXLModelCheckpointConfig']; | s['StableDiffusionXLModelCheckpointConfig'];
export type OnnxModelConfig = s['ONNXStableDiffusion1ModelConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig; export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig = export type AnyModelConfig =
| LoRAModelConfig | LoRAModelConfig
@ -100,8 +98,7 @@ export type AnyModelConfig =
| IPAdapterModelConfig | IPAdapterModelConfig
| T2IAdapterModelConfig | T2IAdapterModelConfig
| TextualInversionModelConfig | TextualInversionModelConfig
| MainModelConfig | MainModelConfig;
| OnnxModelConfig;
export type MergeModelConfig = s['Body_merge_models']; export type MergeModelConfig = s['Body_merge_models'];
export type ImportModelConfig = s['Body_import_model']; export type ImportModelConfig = s['Body_import_model'];
@ -137,13 +134,11 @@ export type CompelInvocation = s['CompelInvocation'];
export type DynamicPromptInvocation = s['DynamicPromptInvocation']; export type DynamicPromptInvocation = s['DynamicPromptInvocation'];
export type NoiseInvocation = s['NoiseInvocation']; export type NoiseInvocation = s['NoiseInvocation'];
export type DenoiseLatentsInvocation = s['DenoiseLatentsInvocation']; export type DenoiseLatentsInvocation = s['DenoiseLatentsInvocation'];
export type ONNXTextToLatentsInvocation = s['ONNXTextToLatentsInvocation'];
export type SDXLLoraLoaderInvocation = s['SDXLLoraLoaderInvocation']; export type SDXLLoraLoaderInvocation = s['SDXLLoraLoaderInvocation'];
export type ImageToLatentsInvocation = s['ImageToLatentsInvocation']; export type ImageToLatentsInvocation = s['ImageToLatentsInvocation'];
export type LatentsToImageInvocation = s['LatentsToImageInvocation']; export type LatentsToImageInvocation = s['LatentsToImageInvocation'];
export type ImageCollectionInvocation = s['ImageCollectionInvocation']; export type ImageCollectionInvocation = s['ImageCollectionInvocation'];
export type MainModelLoaderInvocation = s['MainModelLoaderInvocation']; export type MainModelLoaderInvocation = s['MainModelLoaderInvocation'];
export type OnnxModelLoaderInvocation = s['OnnxModelLoaderInvocation'];
export type LoraLoaderInvocation = s['LoraLoaderInvocation']; export type LoraLoaderInvocation = s['LoraLoaderInvocation'];
export type ESRGANInvocation = s['ESRGANInvocation']; export type ESRGANInvocation = s['ESRGANInvocation'];
export type DivideInvocation = s['DivideInvocation']; export type DivideInvocation = s['DivideInvocation'];