change model store to object, update main model and vae dropdowns

This commit is contained in:
Mary Hipp 2023-07-06 11:54:16 -04:00 committed by psychedelicious
parent 909fe047e4
commit 6356dc335f
10 changed files with 55 additions and 50 deletions

View File

@ -16,10 +16,12 @@ export const addVAEToGraph = (
graph: NonNullableGraph, graph: NonNullableGraph,
state: RootState state: RootState
): void => { ): void => {
const { vae: vaeId } = state.generation; const { vae } = state.generation;
const vae_model = modelIdToVAEModelField(vaeId); const vae_model = modelIdToVAEModelField(vae?.id || '');
if (vaeId !== 'auto') { const isAutoVae = vae?.id === 'auto';
if (!isAutoVae) {
graph.nodes[VAE_LOADER] = { graph.nodes[VAE_LOADER] = {
type: 'vae_loader', type: 'vae_loader',
id: VAE_LOADER, id: VAE_LOADER,
@ -30,7 +32,7 @@ export const addVAEToGraph = (
if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) { if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) {
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -43,7 +45,7 @@ export const addVAEToGraph = (
if (graph.id === IMAGE_TO_IMAGE_GRAPH) { if (graph.id === IMAGE_TO_IMAGE_GRAPH) {
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -56,7 +58,7 @@ export const addVAEToGraph = (
if (graph.id === INPAINT_GRAPH) { if (graph.id === INPAINT_GRAPH) {
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {

View File

@ -36,7 +36,7 @@ export const buildCanvasImageToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: modelId, model: currentModel,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -50,7 +50,7 @@ export const buildCanvasImageToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToMainModelField(modelId); const model = modelIdToMainModelField(currentModel?.id || '');
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the

View File

@ -35,7 +35,7 @@ export const buildCanvasInpaintGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: modelId, model: currentModel,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -59,7 +59,7 @@ export const buildCanvasInpaintGraph = (
// We may need to set the inpaint width and height to scale the image // We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas; const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const model = modelIdToMainModelField(modelId); const model = modelIdToMainModelField(currentModel?.id || '');
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: INPAINT_GRAPH, id: INPAINT_GRAPH,

View File

@ -25,7 +25,7 @@ export const buildCanvasTextToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: modelId, model: currentModel,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -38,7 +38,7 @@ export const buildCanvasTextToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToMainModelField(modelId); const model = modelIdToMainModelField(currentModel?.id || '');
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the

View File

@ -38,7 +38,7 @@ export const buildLinearImageToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: modelId, model: currentModel,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -73,7 +73,7 @@ export const buildLinearImageToImageGraph = (
throw new Error('No initial image found in state'); throw new Error('No initial image found in state');
} }
const model = modelIdToMainModelField(modelId); const model = modelIdToMainModelField(currentModel?.id || '');
// copy-pasted graph from node editor, filled in with state values & friendly node ids // copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = { const graph: NonNullableGraph = {

View File

@ -22,7 +22,7 @@ export const buildLinearTextToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: modelId, model: currentModel,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -31,7 +31,7 @@ export const buildLinearTextToImageGraph = (
clipSkip, clipSkip,
} = state.generation; } = state.generation;
const model = modelIdToMainModelField(modelId); const model = modelIdToMainModelField(currentModel?.id || '');
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the

View File

@ -16,7 +16,6 @@ import {
SeedParam, SeedParam,
StepsParam, StepsParam,
StrengthParam, StrengthParam,
VAEParam,
WidthParam, WidthParam,
} from './parameterZodSchemas'; } from './parameterZodSchemas';
@ -50,7 +49,7 @@ export interface GenerationState {
horizontalSymmetrySteps: number; horizontalSymmetrySteps: number;
verticalSymmetrySteps: number; verticalSymmetrySteps: number;
model: ModelParam; model: ModelParam;
vae: VAEParam; vae: ModelParam;
seamlessXAxis: boolean; seamlessXAxis: boolean;
seamlessYAxis: boolean; seamlessYAxis: boolean;
clipSkip: number; clipSkip: number;
@ -84,8 +83,8 @@ export const initialGenerationState: GenerationState = {
shouldUseSymmetry: false, shouldUseSymmetry: false,
horizontalSymmetrySteps: 0, horizontalSymmetrySteps: 0,
verticalSymmetrySteps: 0, verticalSymmetrySteps: 0,
model: '', model: null,
vae: '', vae: null,
seamlessXAxis: false, seamlessXAxis: false,
seamlessYAxis: false, seamlessYAxis: false,
clipSkip: 0, clipSkip: 0,
@ -216,16 +215,17 @@ export const generationSlice = createSlice({
state.initialImage = { imageName: image_name, width, height }; state.initialImage = { imageName: image_name, width, height };
}, },
modelSelected: (state, action: PayloadAction<string>) => { modelSelected: (state, action: PayloadAction<string>) => {
state.model = action.payload; const [base_model, type, name] = action.payload.split('/');
// Clamp ClipSkip Based On Selected Model // Clamp ClipSkip Based On Selected Model
const clipSkipMax = const { maxClip } = clipSkipMap[base_model as keyof typeof clipSkipMap];
clipSkipMap[action.payload.split('/')[0] as keyof typeof clipSkipMap] state.clipSkip = clamp(state.clipSkip, 0, maxClip);
.maxClip;
state.clipSkip = clamp(state.clipSkip, 0, clipSkipMax); state.model = { id: action.payload, base_model, name, type };
}, },
vaeSelected: (state, action: PayloadAction<string>) => { vaeSelected: (state, action: PayloadAction<string>) => {
state.vae = action.payload; const [base_model, type, name] = action.payload.split('/');
state.vae = { id: action.payload, base_model, name, type };
}, },
setClipSkip: (state, action: PayloadAction<number>) => { setClipSkip: (state, action: PayloadAction<number>) => {
state.clipSkip = action.payload; state.clipSkip = action.payload;
@ -235,7 +235,13 @@ export const generationSlice = createSlice({
builder.addCase(configChanged, (state, action) => { builder.addCase(configChanged, (state, action) => {
const defaultModel = action.payload.sd?.defaultModel; const defaultModel = action.payload.sd?.defaultModel;
if (defaultModel && !state.model) { if (defaultModel && !state.model) {
state.model = defaultModel; const [base_model, model_type, model_name] = defaultModel.split('/');
state.model = {
id: defaultModel,
name: model_name,
type: model_type,
base_model: base_model,
};
} }
}); });
builder.addCase(setShouldShowAdvancedOptions, (state, action) => { builder.addCase(setShouldShowAdvancedOptions, (state, action) => {

View File

@ -130,20 +130,21 @@ export const isValidHeight = (val: unknown): val is HeightParam =>
* Zod schema for model parameter * Zod schema for model parameter
* TODO: Make this a dynamically generated enum? * TODO: Make this a dynamically generated enum?
*/ */
export const zModel = z.string(); const zModel = z.object({
id: z.string(),
name: z.string(),
type: z.string(),
base_model: z.string(),
});
/** /**
* Type alias for model parameter, inferred from its zod schema * Type alias for model parameter, inferred from its zod schema
*/ */
export type ModelParam = z.infer<typeof zModel>; export type ModelParam = z.infer<typeof zModel> | null;
/**
* Zod schema for VAE parameter
* TODO: Make this a dynamically generated enum?
*/
export const zVAE = z.string();
/** /**
* Type alias for model parameter, inferred from its zod schema * Type alias for model parameter, inferred from its zod schema
*/ */
export type VAEParam = z.infer<typeof zVAE>; export type VAEParam = z.infer<typeof zModel> | null;
/** /**
* Validates/type-guards a value as a model parameter * Validates/type-guards a value as a model parameter
*/ */

View File

@ -19,7 +19,7 @@ const ModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const selectedModelId = useAppSelector( const currentModel = useAppSelector(
(state: RootState) => state.generation.model (state: RootState) => state.generation.model
); );
@ -48,8 +48,8 @@ const ModelSelect = () => {
}, [mainModels]); }, [mainModels]);
const selectedModel = useMemo( const selectedModel = useMemo(
() => mainModels?.entities[selectedModelId], () => mainModels?.entities[currentModel?.id || ''],
[mainModels?.entities, selectedModelId] [mainModels?.entities, currentModel]
); );
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
@ -63,10 +63,6 @@ const ModelSelect = () => {
); );
useEffect(() => { useEffect(() => {
if (selectedModelId && mainModels?.ids.includes(selectedModelId)) {
return;
}
const firstModel = mainModels?.ids[0]; const firstModel = mainModels?.ids[0];
if (!isString(firstModel)) { if (!isString(firstModel)) {
@ -74,7 +70,7 @@ const ModelSelect = () => {
} }
handleChangeModel(firstModel); handleChangeModel(firstModel);
}, [handleChangeModel, mainModels?.ids, selectedModelId]); }, [handleChangeModel, mainModels?.ids]);
return isLoading ? ( return isLoading ? (
<IAIMantineSelect <IAIMantineSelect
@ -87,7 +83,7 @@ const ModelSelect = () => {
<IAIMantineSelect <IAIMantineSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={t('modelManager.model')} label={t('modelManager.model')}
value={selectedModelId} value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models detected!'} placeholder={data.length > 0 ? 'Select a model' : 'No models detected!'}
data={data} data={data}
error={data.length === 0} error={data.length === 0}

View File

@ -18,7 +18,7 @@ const VAESelect = () => {
const { data: vaeModels } = useGetVaeModelsQuery(); const { data: vaeModels } = useGetVaeModelsQuery();
const selectedModelId = useAppSelector( const currentModel = useAppSelector(
(state: RootState) => state.generation.vae (state: RootState) => state.generation.vae
); );
@ -51,8 +51,8 @@ const VAESelect = () => {
}, [vaeModels]); }, [vaeModels]);
const selectedModel = useMemo( const selectedModel = useMemo(
() => vaeModels?.entities[selectedModelId], () => vaeModels?.entities[currentModel?.id || ''],
[vaeModels?.entities, selectedModelId] [vaeModels?.entities, currentModel]
); );
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
@ -66,17 +66,17 @@ const VAESelect = () => {
); );
useEffect(() => { useEffect(() => {
if (selectedModelId && vaeModels?.ids.includes(selectedModelId)) { if (currentModel?.id && vaeModels?.ids.includes(currentModel?.id)) {
return; return;
} }
handleChangeModel('auto'); handleChangeModel('auto');
}, [handleChangeModel, vaeModels?.ids, selectedModelId]); }, [handleChangeModel, vaeModels?.ids, currentModel?.id]);
return ( return (
<IAIMantineSelect <IAIMantineSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={t('modelManager.vae')} label={t('modelManager.vae')}
value={selectedModelId} value={currentModel?.id}
placeholder="Pick one" placeholder="Pick one"
data={data} data={data}
onChange={handleChangeModel} onChange={handleChangeModel}