mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): post-onnx fixes
This commit is contained in:
committed by
Kent Keirsey
parent
e86925d424
commit
fb8f218901
@ -35,6 +35,7 @@ import {
|
||||
isValidSDXLNegativeStylePrompt,
|
||||
isValidSDXLPositiveStylePrompt,
|
||||
isValidSDXLRefinerAestheticScore,
|
||||
isValidSDXLRefinerModel,
|
||||
isValidSDXLRefinerStart,
|
||||
isValidScheduler,
|
||||
isValidSeed,
|
||||
@ -381,7 +382,7 @@ export const useRecallParameters = () => {
|
||||
dispatch(setNegativeStylePromptSDXL(negative_style_prompt));
|
||||
}
|
||||
|
||||
if (isValidMainModel(refiner_model)) {
|
||||
if (isValidSDXLRefinerModel(refiner_model)) {
|
||||
dispatch(refinerModelChanged(refiner_model));
|
||||
}
|
||||
|
||||
|
@ -3,13 +3,14 @@ import { createSlice } from '@reduxjs/toolkit';
|
||||
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { clamp } from 'lodash-es';
|
||||
import { ImageDTO, MainModelField, OnnxModelField } from 'services/api/types';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { clipSkipMap } from '../types/constants';
|
||||
import {
|
||||
CfgScaleParam,
|
||||
HeightParam,
|
||||
MainModelParam,
|
||||
NegativePromptParam,
|
||||
OnnxModelParam,
|
||||
PositivePromptParam,
|
||||
PrecisionParam,
|
||||
SchedulerParam,
|
||||
@ -50,7 +51,7 @@ export interface GenerationState {
|
||||
shouldUseSymmetry: boolean;
|
||||
horizontalSymmetrySteps: number;
|
||||
verticalSymmetrySteps: number;
|
||||
model: MainModelField | OnnxModelField | null;
|
||||
model: MainModelParam | OnnxModelParam | null;
|
||||
vae: VaeModelParam | null;
|
||||
vaePrecision: PrecisionParam;
|
||||
seamlessXAxis: boolean;
|
||||
@ -229,7 +230,10 @@ export const generationSlice = createSlice({
|
||||
const { image_name, width, height } = action.payload;
|
||||
state.initialImage = { imageName: image_name, width, height };
|
||||
},
|
||||
modelChanged: (state, action: PayloadAction<MainModelParam | null>) => {
|
||||
modelChanged: (
|
||||
state,
|
||||
action: PayloadAction<MainModelParam | OnnxModelParam | null>
|
||||
) => {
|
||||
state.model = action.payload;
|
||||
|
||||
if (state.model === null) {
|
||||
|
@ -210,42 +210,70 @@ export type HeightParam = z.infer<typeof zHeight>;
|
||||
export const isValidHeight = (val: unknown): val is HeightParam =>
|
||||
zHeight.safeParse(val).success;
|
||||
|
||||
const zModelType = z.enum([
|
||||
'vae',
|
||||
'lora',
|
||||
'onnx',
|
||||
'main',
|
||||
'controlnet',
|
||||
'embedding',
|
||||
]);
|
||||
const zBaseModel = z.enum(['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||
|
||||
export type BaseModelParam = z.infer<typeof zBaseModel>;
|
||||
|
||||
/**
|
||||
* Zod schema for model parameter
|
||||
* Zod schema for main model parameter
|
||||
* TODO: Make this a dynamically generated enum?
|
||||
*/
|
||||
export const zMainModel = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: zBaseModel,
|
||||
model_type: zModelType,
|
||||
model_type: z.literal('main'),
|
||||
});
|
||||
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
* Type alias for main model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type MainModelParam = z.infer<typeof zMainModel>;
|
||||
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type OnnxModelParam = z.infer<typeof zMainModel>;
|
||||
/**
|
||||
* Validates/type-guards a value as a model parameter
|
||||
* Validates/type-guards a value as a main model parameter
|
||||
*/
|
||||
export const isValidMainModel = (val: unknown): val is MainModelParam =>
|
||||
zMainModel.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for SDXL refiner model parameter
|
||||
* TODO: Make this a dynamically generated enum?
|
||||
*/
|
||||
export const zSDXLRefinerModel = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: z.literal('sdxl-refiner'),
|
||||
model_type: z.literal('main'),
|
||||
});
|
||||
/**
|
||||
* Type alias for SDXL refiner model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type SDXLRefinerModelParam = z.infer<typeof zSDXLRefinerModel>;
|
||||
/**
|
||||
* Validates/type-guards a value as a SDXL refiner model parameter
|
||||
*/
|
||||
export const isValidSDXLRefinerModel = (
|
||||
val: unknown
|
||||
): val is SDXLRefinerModelParam => zSDXLRefinerModel.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for Onnx model parameter
|
||||
* TODO: Make this a dynamically generated enum?
|
||||
*/
|
||||
export const zOnnxModel = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: zBaseModel,
|
||||
model_type: z.literal('onnx'),
|
||||
});
|
||||
/**
|
||||
* Type alias for Onnx model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type OnnxModelParam = z.infer<typeof zOnnxModel>;
|
||||
/**
|
||||
* Validates/type-guards a value as a Onnx model parameter
|
||||
*/
|
||||
export const isValidOnnxModel = (val: unknown): val is OnnxModelParam =>
|
||||
zOnnxModel.safeParse(val).success;
|
||||
|
||||
export const zMainOrOnnxModel = z.union([zMainModel, zOnnxModel]);
|
||||
|
||||
/**
|
||||
* Zod schema for VAE parameter
|
||||
*/
|
||||
|
@ -1,16 +1,17 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import {
|
||||
MainModelParam,
|
||||
zMainModel,
|
||||
OnnxModelParam,
|
||||
zMainOrOnnxModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
export const modelIdToMainModelParam = (
|
||||
mainModelId: string
|
||||
): MainModelParam | undefined => {
|
||||
): OnnxModelParam | MainModelParam | undefined => {
|
||||
const log = logger('models');
|
||||
const [base_model, model_type, model_name] = mainModelId.split('/');
|
||||
|
||||
const result = zMainModel.safeParse({
|
||||
const result = zMainOrOnnxModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
model_type,
|
||||
|
@ -0,0 +1,31 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import {
|
||||
SDXLRefinerModelParam,
|
||||
zSDXLRefinerModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
export const modelIdToSDXLRefinerModelParam = (
|
||||
mainModelId: string
|
||||
): SDXLRefinerModelParam | undefined => {
|
||||
const log = logger('models');
|
||||
const [base_model, model_type, model_name] = mainModelId.split('/');
|
||||
|
||||
const result = zSDXLRefinerModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
model_type,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
{
|
||||
mainModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse main model id'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
Reference in New Issue
Block a user