mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): fix workflow editor model selector, excise ONNX
Ensure workflow editor model selector component gets a value This introduced some funky type issues related to ONNX models. ONNX doesn't work anyways (unmaintained). Instead of fixing the types to work with a non-working feature, ONNX is now removed entirely from the UI. - Remove all refs to ONNX (and Olives) - Fix some type issues - Add ONNX nodes to the nodes denylist (so they are not visible in UI) - Update VAE graph helper, which still had some ONNX logic. It's a very simple change and doesn't change any logic. Just removes some conditions that were for ONNX. I tested it and nothing broke. - Regenerate types - Fix prettier and eslint ignores for generated types - Lint
This commit is contained in:
@ -45,6 +45,7 @@ const MainModelFieldInputComponent = (props: Props) => {
|
||||
modelEntities: data,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
|
@ -45,6 +45,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
|
||||
modelEntities: data,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
|
@ -45,6 +45,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
|
||||
modelEntities: data,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
|
@ -137,10 +137,11 @@ const fieldValueReducer = <T extends FieldValue>(
|
||||
return;
|
||||
}
|
||||
const input = node.data?.inputs[fieldName];
|
||||
if (!input || nodeIndex < 0 || !schema.safeParse(value).success) {
|
||||
const result = schema.safeParse(value);
|
||||
if (!input || nodeIndex < 0 || !result.success) {
|
||||
return;
|
||||
}
|
||||
input.value = value;
|
||||
input.value = result.data;
|
||||
};
|
||||
|
||||
const nodesSlice = createSlice({
|
||||
|
@ -59,7 +59,6 @@ export const zBaseModel = z.enum([
|
||||
'sdxl-refiner',
|
||||
]);
|
||||
export const zModelType = z.enum([
|
||||
'onnx',
|
||||
'main',
|
||||
'vae',
|
||||
'lora',
|
||||
@ -80,23 +79,12 @@ export const zMainModelField = z.object({
|
||||
base_model: zBaseModel,
|
||||
model_type: z.literal('main'),
|
||||
});
|
||||
export const zONNXModelField = z.object({
|
||||
model_name: zModelName,
|
||||
base_model: zBaseModel,
|
||||
model_type: z.literal('onnx'),
|
||||
});
|
||||
export const zMainOrONNXModelField = z.union([
|
||||
zMainModelField,
|
||||
zONNXModelField,
|
||||
]);
|
||||
export const zSDXLRefinerModelField = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: z.literal('sdxl-refiner'),
|
||||
model_type: z.literal('main'),
|
||||
});
|
||||
export type MainModelField = z.infer<typeof zMainModelField>;
|
||||
export type ONNXModelField = z.infer<typeof zONNXModelField>;
|
||||
export type MainOrONNXModelField = z.infer<typeof zMainOrONNXModelField>;
|
||||
export type SDXLRefinerModelField = z.infer<typeof zSDXLRefinerModelField>;
|
||||
|
||||
export const zSubModelType = z.enum([
|
||||
|
@ -39,7 +39,6 @@ export const MODEL_TYPES = [
|
||||
'ControlNetModelField',
|
||||
'LoRAModelField',
|
||||
'MainModelField',
|
||||
'ONNXModelField',
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VaeModelField',
|
||||
@ -70,7 +69,6 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
LatentsField: 'pink.500',
|
||||
LoRAModelField: 'teal.500',
|
||||
MainModelField: 'teal.500',
|
||||
ONNXModelField: 'teal.500',
|
||||
SDXLMainModelField: 'teal.500',
|
||||
SDXLRefinerModelField: 'teal.500',
|
||||
StringField: 'yellow.500',
|
||||
|
@ -7,7 +7,7 @@ import {
|
||||
zImageField,
|
||||
zIPAdapterModelField,
|
||||
zLoRAModelField,
|
||||
zMainOrONNXModelField,
|
||||
zMainModelField,
|
||||
zSchedulerField,
|
||||
zT2IAdapterModelField,
|
||||
zVAEModelField,
|
||||
@ -430,7 +430,7 @@ export const isColorFieldInputTemplate = (
|
||||
export const zMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('MainModelField'),
|
||||
});
|
||||
export const zMainModelFieldValue = zMainOrONNXModelField.optional();
|
||||
export const zMainModelFieldValue = zMainModelField.optional();
|
||||
export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
value: zMainModelFieldValue,
|
||||
|
@ -5,7 +5,6 @@ import {
|
||||
zIPAdapterField,
|
||||
zLoRAModelField,
|
||||
zMainModelField,
|
||||
zONNXModelField,
|
||||
zSDXLRefinerModelField,
|
||||
zT2IAdapterField,
|
||||
zVAEModelField,
|
||||
@ -23,10 +22,7 @@ const zControlNetMetadataItem = zControlField.deepPartial();
|
||||
const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
|
||||
const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial();
|
||||
const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial();
|
||||
const zModelMetadataItem = z.union([
|
||||
zMainModelField.deepPartial(),
|
||||
zONNXModelField.deepPartial(),
|
||||
]);
|
||||
const zModelMetadataItem = zMainModelField.deepPartial();
|
||||
const zVAEModelMetadataItem = zVAEModelField.deepPartial();
|
||||
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
||||
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
|
||||
|
@ -273,11 +273,6 @@ const FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2: {
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: true,
|
||||
},
|
||||
ONNXModelField: {
|
||||
name: 'ONNXModelField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
T2IAdapterField: {
|
||||
name: 'T2IAdapterField',
|
||||
isCollection: false,
|
||||
|
@ -14,7 +14,6 @@ import {
|
||||
INPAINT_IMAGE,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
ONNX_MODEL_LOADER,
|
||||
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||
SDXL_CANVAS_INPAINT_GRAPH,
|
||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||
@ -50,7 +49,6 @@ export const addVAEToGraph = (
|
||||
vae_model: vae,
|
||||
};
|
||||
}
|
||||
const isOnnxModel = modelLoaderNodeId == ONNX_MODEL_LOADER;
|
||||
|
||||
if (
|
||||
graph.id === TEXT_TO_IMAGE_GRAPH ||
|
||||
@ -61,7 +59,7 @@ export const addVAEToGraph = (
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
@ -79,7 +77,7 @@ export const addVAEToGraph = (
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: isUsingScaledDimensions ? LATENTS_TO_IMAGE : CANVAS_OUTPUT,
|
||||
@ -97,7 +95,7 @@ export const addVAEToGraph = (
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
@ -116,7 +114,7 @@ export const addVAEToGraph = (
|
||||
{
|
||||
source: {
|
||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_IMAGE,
|
||||
@ -126,7 +124,7 @@ export const addVAEToGraph = (
|
||||
{
|
||||
source: {
|
||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
@ -136,7 +134,7 @@ export const addVAEToGraph = (
|
||||
{
|
||||
source: {
|
||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
@ -150,7 +148,7 @@ export const addVAEToGraph = (
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
@ -168,7 +166,7 @@ export const addVAEToGraph = (
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
|
@ -85,7 +85,6 @@ export const buildCanvasSDXLTextToImageGraph = (
|
||||
*/
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
// TODO: Actually create the graph correctly for ONNX
|
||||
const graph: NonNullableGraph = {
|
||||
id: SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
|
@ -78,7 +78,6 @@ export const buildCanvasTextToImageGraph = (
|
||||
*/
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
// TODO: Actually create the graph correctly for ONNX
|
||||
const graph: NonNullableGraph = {
|
||||
id: CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
|
@ -70,7 +70,6 @@ export const buildLinearTextToImageGraph = (
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
|
||||
// TODO: Actually create the graph correctly for ONNX
|
||||
const graph: NonNullableGraph = {
|
||||
id: TEXT_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
|
@ -18,7 +18,6 @@ export const RANDOM_INT = 'rand_int';
|
||||
export const RANGE_OF_SIZE = 'range_of_size';
|
||||
export const ITERATE = 'iterate';
|
||||
export const MAIN_MODEL_LOADER = 'main_model_loader';
|
||||
export const ONNX_MODEL_LOADER = 'onnx_model_loader';
|
||||
export const VAE_LOADER = 'vae_loader';
|
||||
export const LORA_LOADER = 'lora_loader';
|
||||
export const CLIP_SKIP = 'clip_skip';
|
||||
|
@ -24,7 +24,14 @@ const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache'];
|
||||
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
|
||||
const RESERVED_FIELD_TYPES = ['IsIntermediate'];
|
||||
|
||||
const invocationDenylist: string[] = ['graph', 'linear_ui_output'];
|
||||
const invocationDenylist: string[] = [
|
||||
'graph',
|
||||
'linear_ui_output',
|
||||
'l2i_onnx',
|
||||
'prompt_onnx',
|
||||
't2l_onnx',
|
||||
'onnx_model_loader',
|
||||
];
|
||||
|
||||
const isReservedInputField = (nodeType: string, fieldName: string) => {
|
||||
if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
|
||||
|
Reference in New Issue
Block a user