Add model_type to the model state object

This commit is contained in:
Brandon Rising 2023-07-18 22:40:27 -04:00
parent e201ad2f51
commit 487455ef2e
8 changed files with 37 additions and 4 deletions

View File

@ -56,6 +56,7 @@ class MainModelField(BaseModel):
model_name: str = Field(description="Name of the model") model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")
class LoRAModelField(BaseModel): class LoRAModelField(BaseModel):

View File

@ -478,6 +478,7 @@ class OnnxModelField(BaseModel):
model_name: str = Field(description="Name of the model") model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")
class OnnxModelLoaderInvocation(BaseInvocation): class OnnxModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels.""" """Loads a main model, outputting its submodels."""

View File

@ -6,6 +6,8 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { BaseModelType, OnnxModelField } from 'services/api/types';
import { import {
CLIP_SKIP, CLIP_SKIP,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,

View File

@ -1,4 +1,4 @@
import { BaseModelType, MainModelField } from 'services/api/types'; import { BaseModelType, MainModelField, ModelType } from 'services/api/types';
/** /**
* Crudely converts a model id to a main model field * Crudely converts a model id to a main model field
@ -9,6 +9,7 @@ export const modelIdToMainModelField = (modelId: string): MainModelField => {
const field: MainModelField = { const field: MainModelField = {
base_model: base_model as BaseModelType, base_model: base_model as BaseModelType,
model_type: model_type as ModelType,
model_name, model_name,
}; };

View File

@ -1,4 +1,4 @@
import { BaseModelType, OnnxModelField } from 'services/api/types'; import { BaseModelType, OnnxModelField, ModelType } from 'services/api/types';
/** /**
* Crudely converts a model id to a main model field * Crudely converts a model id to a main model field
@ -10,6 +10,7 @@ export const modelIdToOnnxModelField = (modelId: string): OnnxModelField => {
const field: OnnxModelField = { const field: OnnxModelField = {
base_model: base_model as BaseModelType, base_model: base_model as BaseModelType,
model_name, model_name,
model_type: model_type as ModelType,
}; };
return field; return field;

View File

@ -260,6 +260,7 @@ export const generationSlice = createSlice({
id: defaultModel, id: defaultModel,
name: model_name, name: model_name,
base_model, base_model,
model_type,
}); });
} }
}); });

View File

@ -127,6 +127,14 @@ export const isValidHeight = (val: unknown): val is HeightParam =>
zHeight.safeParse(val).success; zHeight.safeParse(val).success;
const zBaseModel = z.enum(['sd-1', 'sd-2']); const zBaseModel = z.enum(['sd-1', 'sd-2']);
const zModelType = z.enum([
'vae',
'lora',
'onnx',
'main',
'controlnet',
'embedding',
]);
export type BaseModelParam = z.infer<typeof zBaseModel>; export type BaseModelParam = z.infer<typeof zBaseModel>;
@ -137,6 +145,7 @@ export type BaseModelParam = z.infer<typeof zBaseModel>;
export const zMainModel = z.object({ export const zMainModel = z.object({
model_name: z.string(), model_name: z.string(),
base_model: zBaseModel, base_model: zBaseModel,
model_type: zModelType,
}); });
/** /**

View File

@ -2864,7 +2864,7 @@ export type components = {
clip?: components["schemas"]["ClipField"]; clip?: components["schemas"]["ClipField"];
}; };
/** /**
* MainModelField * MainModelField
* @description Main model field * @description Main model field
*/ */
MainModelField: { MainModelField: {
@ -2875,6 +2875,23 @@ export type components = {
model_name: string; model_name: string;
/** @description Base model */ /** @description Base model */
base_model: components["schemas"]["BaseModelType"]; base_model: components["schemas"]["BaseModelType"];
/** @description Model Type */
model_type: components["schemas"]["ModelType"];
};
/**
* OnnxModelField
* @description Onnx model field
*/
OnnxModelField: {
/**
* Model Name
* @description Name of the model
*/
model_name: string;
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
/** @description Model Type */
model_type: components["schemas"]["ModelType"];
}; };
/** /**
* MainModelLoaderInvocation * MainModelLoaderInvocation
@ -3308,7 +3325,7 @@ export type components = {
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
ModelType: "main" | "vae" | "lora" | "controlnet" | "embedding"; ModelType: "main" | "onnx" | "vae" | "lora" | "controlnet" | "embedding";
/** /**
* ModelVariantType * ModelVariantType
* @description An enumeration. * @description An enumeration.