Add model_type to the model state object

This commit is contained in:
Brandon Rising 2023-07-18 22:40:27 -04:00
parent e201ad2f51
commit 487455ef2e
8 changed files with 37 additions and 4 deletions

View File

@ -56,6 +56,7 @@ class MainModelField(BaseModel):
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")
class LoRAModelField(BaseModel):

View File

@ -478,6 +478,7 @@ class OnnxModelField(BaseModel):
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")
class OnnxModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""

View File

@ -6,6 +6,8 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { BaseModelType, OnnxModelField } from 'services/api/types';
import {
CLIP_SKIP,
LATENTS_TO_IMAGE,

View File

@ -1,4 +1,4 @@
import { BaseModelType, MainModelField } from 'services/api/types';
import { BaseModelType, MainModelField, ModelType } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
@ -9,6 +9,7 @@ export const modelIdToMainModelField = (modelId: string): MainModelField => {
const field: MainModelField = {
base_model: base_model as BaseModelType,
model_type: model_type as ModelType,
model_name,
};

View File

@ -1,4 +1,4 @@
import { BaseModelType, OnnxModelField } from 'services/api/types';
import { BaseModelType, OnnxModelField, ModelType } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
@ -10,6 +10,7 @@ export const modelIdToOnnxModelField = (modelId: string): OnnxModelField => {
const field: OnnxModelField = {
base_model: base_model as BaseModelType,
model_name,
model_type: model_type as ModelType,
};
return field;

View File

@ -260,6 +260,7 @@ export const generationSlice = createSlice({
id: defaultModel,
name: model_name,
base_model,
model_type,
});
}
});

View File

@ -127,6 +127,14 @@ export const isValidHeight = (val: unknown): val is HeightParam =>
zHeight.safeParse(val).success;
const zBaseModel = z.enum(['sd-1', 'sd-2']);
const zModelType = z.enum([
'vae',
'lora',
'onnx',
'main',
'controlnet',
'embedding',
]);
export type BaseModelParam = z.infer<typeof zBaseModel>;
@ -137,6 +145,7 @@ export type BaseModelParam = z.infer<typeof zBaseModel>;
export const zMainModel = z.object({
model_name: z.string(),
base_model: zBaseModel,
model_type: zModelType,
});
/**

View File

@ -2864,7 +2864,7 @@ export type components = {
clip?: components["schemas"]["ClipField"];
};
/**
* MainModelField
* MainModelField
* @description Main model field
*/
MainModelField: {
@ -2875,6 +2875,23 @@ export type components = {
model_name: string;
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
/** @description Model Type */
model_type: components["schemas"]["ModelType"];
};
/**
* OnnxModelField
* @description Onnx model field
*/
OnnxModelField: {
/**
* Model Name
* @description Name of the model
*/
model_name: string;
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
/** @description Model Type */
model_type: components["schemas"]["ModelType"];
};
/**
* MainModelLoaderInvocation
@ -3308,7 +3325,7 @@ export type components = {
* @description An enumeration.
* @enum {string}
*/
ModelType: "main" | "vae" | "lora" | "controlnet" | "embedding";
ModelType: "main" | "onnx" | "vae" | "lora" | "controlnet" | "embedding";
/**
* ModelVariantType
* @description An enumeration.