fix(ui): post-onnx fixes

This commit is contained in:
psychedelicious
2023-08-01 14:26:02 +10:00
committed by Kent Keirsey
parent e86925d424
commit fb8f218901
18 changed files with 229 additions and 106 deletions

View File

@ -35,6 +35,7 @@ import {
isValidSDXLNegativeStylePrompt,
isValidSDXLPositiveStylePrompt,
isValidSDXLRefinerAestheticScore,
isValidSDXLRefinerModel,
isValidSDXLRefinerStart,
isValidScheduler,
isValidSeed,
@ -381,7 +382,7 @@ export const useRecallParameters = () => {
dispatch(setNegativeStylePromptSDXL(negative_style_prompt));
}
if (isValidMainModel(refiner_model)) {
if (isValidSDXLRefinerModel(refiner_model)) {
dispatch(refinerModelChanged(refiner_model));
}

View File

@ -3,13 +3,14 @@ import { createSlice } from '@reduxjs/toolkit';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { configChanged } from 'features/system/store/configSlice';
import { clamp } from 'lodash-es';
import { ImageDTO, MainModelField, OnnxModelField } from 'services/api/types';
import { ImageDTO } from 'services/api/types';
import { clipSkipMap } from '../types/constants';
import {
CfgScaleParam,
HeightParam,
MainModelParam,
NegativePromptParam,
OnnxModelParam,
PositivePromptParam,
PrecisionParam,
SchedulerParam,
@ -50,7 +51,7 @@ export interface GenerationState {
shouldUseSymmetry: boolean;
horizontalSymmetrySteps: number;
verticalSymmetrySteps: number;
model: MainModelField | OnnxModelField | null;
model: MainModelParam | OnnxModelParam | null;
vae: VaeModelParam | null;
vaePrecision: PrecisionParam;
seamlessXAxis: boolean;
@ -229,7 +230,10 @@ export const generationSlice = createSlice({
const { image_name, width, height } = action.payload;
state.initialImage = { imageName: image_name, width, height };
},
modelChanged: (state, action: PayloadAction<MainModelParam | null>) => {
modelChanged: (
state,
action: PayloadAction<MainModelParam | OnnxModelParam | null>
) => {
state.model = action.payload;
if (state.model === null) {

View File

@ -210,42 +210,70 @@ export type HeightParam = z.infer<typeof zHeight>;
export const isValidHeight = (val: unknown): val is HeightParam =>
zHeight.safeParse(val).success;
const zModelType = z.enum([
'vae',
'lora',
'onnx',
'main',
'controlnet',
'embedding',
]);
const zBaseModel = z.enum(['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
export type BaseModelParam = z.infer<typeof zBaseModel>;
/**
* Zod schema for model parameter
* Zod schema for main model parameter
* TODO: Make this a dynamically generated enum?
*/
export const zMainModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
model_type: zModelType,
model_type: z.literal('main'),
});
/**
* Type alias for model parameter, inferred from its zod schema
* Type alias for main model parameter, inferred from its zod schema
*/
export type MainModelParam = z.infer<typeof zMainModel>;
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type OnnxModelParam = z.infer<typeof zMainModel>;
/**
* Validates/type-guards a value as a model parameter
* Validates/type-guards a value as a main model parameter
*/
export const isValidMainModel = (val: unknown): val is MainModelParam =>
zMainModel.safeParse(val).success;
/**
* Zod schema for SDXL refiner model parameter
* TODO: Make this a dynamically generated enum?
*/
export const zSDXLRefinerModel = z.object({
model_name: z.string().min(1),
base_model: z.literal('sdxl-refiner'),
model_type: z.literal('main'),
});
/**
* Type alias for SDXL refiner model parameter, inferred from its zod schema
*/
export type SDXLRefinerModelParam = z.infer<typeof zSDXLRefinerModel>;
/**
* Validates/type-guards a value as a SDXL refiner model parameter
*/
export const isValidSDXLRefinerModel = (
val: unknown
): val is SDXLRefinerModelParam => zSDXLRefinerModel.safeParse(val).success;
/**
* Zod schema for Onnx model parameter
* TODO: Make this a dynamically generated enum?
*/
export const zOnnxModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
model_type: z.literal('onnx'),
});
/**
* Type alias for Onnx model parameter, inferred from its zod schema
*/
export type OnnxModelParam = z.infer<typeof zOnnxModel>;
/**
* Validates/type-guards a value as a Onnx model parameter
*/
export const isValidOnnxModel = (val: unknown): val is OnnxModelParam =>
zOnnxModel.safeParse(val).success;
export const zMainOrOnnxModel = z.union([zMainModel, zOnnxModel]);
/**
* Zod schema for VAE parameter
*/

View File

@ -1,16 +1,17 @@
import { logger } from 'app/logging/logger';
import {
MainModelParam,
zMainModel,
OnnxModelParam,
zMainOrOnnxModel,
} from 'features/parameters/types/parameterSchemas';
export const modelIdToMainModelParam = (
mainModelId: string
): MainModelParam | undefined => {
): OnnxModelParam | MainModelParam | undefined => {
const log = logger('models');
const [base_model, model_type, model_name] = mainModelId.split('/');
const result = zMainModel.safeParse({
const result = zMainOrOnnxModel.safeParse({
base_model,
model_name,
model_type,

View File

@ -0,0 +1,31 @@
import { logger } from 'app/logging/logger';
import {
SDXLRefinerModelParam,
zSDXLRefinerModel,
} from 'features/parameters/types/parameterSchemas';
export const modelIdToSDXLRefinerModelParam = (
mainModelId: string
): SDXLRefinerModelParam | undefined => {
const log = logger('models');
const [base_model, model_type, model_name] = mainModelId.split('/');
const result = zSDXLRefinerModel.safeParse({
base_model,
model_name,
model_type,
});
if (!result.success) {
log.error(
{
mainModelId,
errors: result.error.format(),
},
'Failed to parse main model id'
);
return;
}
return result.data;
};