feat(ui): finalize base model compatibility for lora, ti, vae

This commit is contained in:
psychedelicious
2023-07-07 21:23:03 +10:00
parent a9a4081f51
commit 8457fcf7d3
13 changed files with 187 additions and 113 deletions

View File

@ -1,6 +1,10 @@
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { isImageField } from 'services/api/guards';
import { ImageDTO } from 'services/api/types';
import { initialImageSelected, modelSelected } from '../store/actions';
import {
setCfgScale,
setHeight,
@ -12,14 +16,10 @@ import {
setSteps,
setWidth,
} from '../store/generationSlice';
import { isImageField } from 'services/api/guards';
import { initialImageSelected, modelSelected } from '../store/actions';
import { useAppToaster } from 'app/components/Toaster';
import { ImageDTO } from 'services/api/types';
import {
isValidCfgScale,
isValidHeight,
isValidModel,
isValidMainModel,
isValidNegativePrompt,
isValidPositivePrompt,
isValidScheduler,
@ -158,7 +158,7 @@ export const useRecallParameters = () => {
*/
const recallModel = useCallback(
(model: unknown) => {
if (!isValidModel(model)) {
if (!isValidMainModel(model)) {
parameterNotSetToast();
return;
}
@ -295,7 +295,7 @@ export const useRecallParameters = () => {
if (isValidCfgScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
}
if (isValidModel(model)) {
if (isValidMainModel(model)) {
dispatch(modelSelected(model));
}
if (isValidPositivePrompt(positive_conditioning)) {

View File

@ -9,14 +9,16 @@ import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
import {
CfgScaleParam,
HeightParam,
ModelParam,
MainModelParam,
NegativePromptParam,
PositivePromptParam,
SchedulerParam,
SeedParam,
StepsParam,
StrengthParam,
VaeModelParam,
WidthParam,
zMainModel,
} from './parameterZodSchemas';
export interface GenerationState {
@ -48,8 +50,8 @@ export interface GenerationState {
shouldUseSymmetry: boolean;
horizontalSymmetrySteps: number;
verticalSymmetrySteps: number;
model: ModelParam;
vae: VAEParam;
model: MainModelParam | null;
vae: VaeModelParam | null;
seamlessXAxis: boolean;
seamlessYAxis: boolean;
clipSkip: number;
@ -84,7 +86,7 @@ export const initialGenerationState: GenerationState = {
horizontalSymmetrySteps: 0,
verticalSymmetrySteps: 0,
model: null,
vae: '',
vae: null,
seamlessXAxis: false,
seamlessYAxis: false,
clipSkip: 0,
@ -221,12 +223,17 @@ export const generationSlice = createSlice({
const { maxClip } = clipSkipMap[base_model as keyof typeof clipSkipMap];
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
state.model = { id: action.payload, base_model, name, type };
state.model = zMainModel.parse({
id: action.payload,
base_model,
name,
type,
});
},
modelChanged: (state, action: PayloadAction<ModelParam>) => {
modelChanged: (state, action: PayloadAction<MainModelParam>) => {
state.model = action.payload;
},
vaeSelected: (state, action: PayloadAction<string>) => {
vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
state.vae = action.payload;
},
setClipSkip: (state, action: PayloadAction<number>) => {
@ -236,14 +243,14 @@ export const generationSlice = createSlice({
extraReducers: (builder) => {
builder.addCase(configChanged, (state, action) => {
const defaultModel = action.payload.sd?.defaultModel;
if (defaultModel && !state.model) {
const [base_model, model_type, model_name] = defaultModel.split('/');
state.model = {
state.model = zMainModel.parse({
id: defaultModel,
name: model_name,
type: model_type,
base_model: base_model,
};
base_model,
});
}
});
builder.addCase(setShouldShowAdvancedOptions, (state, action) => {

View File

@ -126,35 +126,63 @@ export type HeightParam = z.infer<typeof zHeight>;
export const isValidHeight = (val: unknown): val is HeightParam =>
zHeight.safeParse(val).success;
const zBaseModel = z.enum(['sd-1', 'sd-2']);
export type BaseModelParam = z.infer<typeof zBaseModel>;
/**
* Zod schema for model parameter
* TODO: Make this a dynamically generated enum?
*/
const zModel = z.object({
export const zMainModel = z.object({
id: z.string(),
name: z.string(),
type: z.string(),
base_model: z.string(),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type ModelParam = z.infer<typeof zModel> | null;
/**
* Zod schema for VAE parameter
* TODO: Make this a dynamically generated enum?
*/
export const zVAE = z.string();
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type VAEParam = z.infer<typeof zVAE>;
export type MainModelParam = z.infer<typeof zMainModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidModel = (val: unknown): val is ModelParam =>
zModel.safeParse(val).success;
export const isValidMainModel = (val: unknown): val is MainModelParam =>
zMainModel.safeParse(val).success;
/**
* Zod schema for VAE parameter
*/
export const zVaeModel = z.object({
id: z.string(),
name: z.string(),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type VaeModelParam = z.infer<typeof zVaeModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
zVaeModel.safeParse(val).success;
/**
* Zod schema for LoRA
*/
export const zLoRAModel = z.object({
id: z.string(),
name: z.string(),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type LoRAModelParam = z.infer<typeof zLoRAModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
zLoRAModel.safeParse(val).success;
/**
* Zod schema for l2l strength parameter