mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): finalize base model compatibility for lora, ti, vae
This commit is contained in:
@ -1,6 +1,10 @@
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { isImageField } from 'services/api/guards';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||
import {
|
||||
setCfgScale,
|
||||
setHeight,
|
||||
@ -12,14 +16,10 @@ import {
|
||||
setSteps,
|
||||
setWidth,
|
||||
} from '../store/generationSlice';
|
||||
import { isImageField } from 'services/api/guards';
|
||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import {
|
||||
isValidCfgScale,
|
||||
isValidHeight,
|
||||
isValidModel,
|
||||
isValidMainModel,
|
||||
isValidNegativePrompt,
|
||||
isValidPositivePrompt,
|
||||
isValidScheduler,
|
||||
@ -158,7 +158,7 @@ export const useRecallParameters = () => {
|
||||
*/
|
||||
const recallModel = useCallback(
|
||||
(model: unknown) => {
|
||||
if (!isValidModel(model)) {
|
||||
if (!isValidMainModel(model)) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
@ -295,7 +295,7 @@ export const useRecallParameters = () => {
|
||||
if (isValidCfgScale(cfg_scale)) {
|
||||
dispatch(setCfgScale(cfg_scale));
|
||||
}
|
||||
if (isValidModel(model)) {
|
||||
if (isValidMainModel(model)) {
|
||||
dispatch(modelSelected(model));
|
||||
}
|
||||
if (isValidPositivePrompt(positive_conditioning)) {
|
||||
|
@ -9,14 +9,16 @@ import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
|
||||
import {
|
||||
CfgScaleParam,
|
||||
HeightParam,
|
||||
ModelParam,
|
||||
MainModelParam,
|
||||
NegativePromptParam,
|
||||
PositivePromptParam,
|
||||
SchedulerParam,
|
||||
SeedParam,
|
||||
StepsParam,
|
||||
StrengthParam,
|
||||
VaeModelParam,
|
||||
WidthParam,
|
||||
zMainModel,
|
||||
} from './parameterZodSchemas';
|
||||
|
||||
export interface GenerationState {
|
||||
@ -48,8 +50,8 @@ export interface GenerationState {
|
||||
shouldUseSymmetry: boolean;
|
||||
horizontalSymmetrySteps: number;
|
||||
verticalSymmetrySteps: number;
|
||||
model: ModelParam;
|
||||
vae: VAEParam;
|
||||
model: MainModelParam | null;
|
||||
vae: VaeModelParam | null;
|
||||
seamlessXAxis: boolean;
|
||||
seamlessYAxis: boolean;
|
||||
clipSkip: number;
|
||||
@ -84,7 +86,7 @@ export const initialGenerationState: GenerationState = {
|
||||
horizontalSymmetrySteps: 0,
|
||||
verticalSymmetrySteps: 0,
|
||||
model: null,
|
||||
vae: '',
|
||||
vae: null,
|
||||
seamlessXAxis: false,
|
||||
seamlessYAxis: false,
|
||||
clipSkip: 0,
|
||||
@ -221,12 +223,17 @@ export const generationSlice = createSlice({
|
||||
const { maxClip } = clipSkipMap[base_model as keyof typeof clipSkipMap];
|
||||
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
|
||||
|
||||
state.model = { id: action.payload, base_model, name, type };
|
||||
state.model = zMainModel.parse({
|
||||
id: action.payload,
|
||||
base_model,
|
||||
name,
|
||||
type,
|
||||
});
|
||||
},
|
||||
modelChanged: (state, action: PayloadAction<ModelParam>) => {
|
||||
modelChanged: (state, action: PayloadAction<MainModelParam>) => {
|
||||
state.model = action.payload;
|
||||
},
|
||||
vaeSelected: (state, action: PayloadAction<string>) => {
|
||||
vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
|
||||
state.vae = action.payload;
|
||||
},
|
||||
setClipSkip: (state, action: PayloadAction<number>) => {
|
||||
@ -236,14 +243,14 @@ export const generationSlice = createSlice({
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(configChanged, (state, action) => {
|
||||
const defaultModel = action.payload.sd?.defaultModel;
|
||||
|
||||
if (defaultModel && !state.model) {
|
||||
const [base_model, model_type, model_name] = defaultModel.split('/');
|
||||
state.model = {
|
||||
state.model = zMainModel.parse({
|
||||
id: defaultModel,
|
||||
name: model_name,
|
||||
type: model_type,
|
||||
base_model: base_model,
|
||||
};
|
||||
base_model,
|
||||
});
|
||||
}
|
||||
});
|
||||
builder.addCase(setShouldShowAdvancedOptions, (state, action) => {
|
||||
|
@ -126,35 +126,63 @@ export type HeightParam = z.infer<typeof zHeight>;
|
||||
export const isValidHeight = (val: unknown): val is HeightParam =>
|
||||
zHeight.safeParse(val).success;
|
||||
|
||||
const zBaseModel = z.enum(['sd-1', 'sd-2']);
|
||||
|
||||
export type BaseModelParam = z.infer<typeof zBaseModel>;
|
||||
|
||||
/**
|
||||
* Zod schema for model parameter
|
||||
* TODO: Make this a dynamically generated enum?
|
||||
*/
|
||||
const zModel = z.object({
|
||||
export const zMainModel = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
type: z.string(),
|
||||
base_model: z.string(),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type ModelParam = z.infer<typeof zModel> | null;
|
||||
/**
|
||||
* Zod schema for VAE parameter
|
||||
* TODO: Make this a dynamically generated enum?
|
||||
*/
|
||||
export const zVAE = z.string();
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type VAEParam = z.infer<typeof zVAE>;
|
||||
export type MainModelParam = z.infer<typeof zMainModel>;
|
||||
/**
|
||||
* Validates/type-guards a value as a model parameter
|
||||
*/
|
||||
export const isValidModel = (val: unknown): val is ModelParam =>
|
||||
zModel.safeParse(val).success;
|
||||
export const isValidMainModel = (val: unknown): val is MainModelParam =>
|
||||
zMainModel.safeParse(val).success;
|
||||
/**
|
||||
* Zod schema for VAE parameter
|
||||
*/
|
||||
export const zVaeModel = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type VaeModelParam = z.infer<typeof zVaeModel>;
|
||||
/**
|
||||
* Validates/type-guards a value as a model parameter
|
||||
*/
|
||||
export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
|
||||
zVaeModel.safeParse(val).success;
|
||||
/**
|
||||
* Zod schema for LoRA
|
||||
*/
|
||||
export const zLoRAModel = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type LoRAModelParam = z.infer<typeof zLoRAModel>;
|
||||
/**
|
||||
* Validates/type-guards a value as a model parameter
|
||||
*/
|
||||
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
|
||||
zLoRAModel.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for l2l strength parameter
|
||||
|
Reference in New Issue
Block a user