From 61c426f5023785672e71b44e518f52e8b27e2f44 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 18 Jun 2023 08:27:13 +1200 Subject: [PATCH] feat: Enable 2.x Model Generation in Linear UI --- .../graphBuilders/buildImageToImageGraph.ts | 3 ++- .../graphBuilders/buildTextToImageGraph.ts | 3 ++- .../parameters/store/generationSlice.ts | 7 ++++++ .../system/components/ModelSelect.tsx | 24 +++++++++++++++++-- 4 files changed, 33 insertions(+), 4 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts index 83dccc4894..9aeb992823 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts @@ -40,6 +40,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => { positivePrompt, negativePrompt, model, + currentModelType, cfgScale: cfg_scale, scheduler, steps, @@ -66,7 +67,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => { // Create the model loader node const modelLoaderNode: SD1ModelLoaderInvocation | SD2ModelLoaderInvocation = { id: MODEL_LOADER, - type: 'sd1_model_loader', + type: currentModelType, model_name: model, }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts index 3e049a953f..8545bebdf2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts @@ -30,6 +30,7 @@ const ITERATE = 'iterate'; export const buildTextToImageGraph = (state: RootState): Graph => { const { model, + currentModelType, positivePrompt, negativePrompt, cfgScale: cfg_scale, @@ -50,7 +51,7 @@ export const buildTextToImageGraph = (state: RootState): Graph => { // Create the model loader node const modelLoaderNode: SD1ModelLoaderInvocation | SD2ModelLoaderInvocation = { id: MODEL_LOADER, - type: 'sd1_model_loader', + type: currentModelType, model_name: model, }; diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index cdf26470d2..4244b512fb 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -1,6 +1,7 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import { Scheduler } from 'app/constants'; +import { ModelLoaderTypes } from 'features/system/components/ModelSelect'; import { configChanged } from 'features/system/store/configSlice'; import { clamp, sortBy } from 'lodash-es'; import { ImageDTO } from 'services/api'; @@ -49,6 +50,7 @@ export interface GenerationState { horizontalSymmetrySteps: number; verticalSymmetrySteps: number; model: ModelParam; + currentModelType: ModelLoaderTypes; shouldUseSeamless: boolean; seamlessXAxis: boolean; seamlessYAxis: boolean; @@ -83,6 +85,7 @@ export const initialGenerationState: GenerationState = { horizontalSymmetrySteps: 0, verticalSymmetrySteps: 0, model: '', + currentModelType: 'sd1_model_loader', shouldUseSeamless: false, seamlessXAxis: true, seamlessYAxis: true, @@ -217,6 +220,9 @@ export const generationSlice = createSlice({ modelSelected: (state, action: PayloadAction) => { state.model = action.payload; }, + setCurrentModelType: (state, action: PayloadAction) => { + state.currentModelType = action.payload; + }, }, extraReducers: (builder) => { builder.addCase(getModels.fulfilled, (state, action) => { @@ -277,6 +283,7 @@ export const { setVerticalSymmetrySteps, initialImageChanged, modelSelected, + setCurrentModelType, setShouldUseNoiseSettings, setSeamless, setSeamlessXAxis, diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index a65c8501dc..bf0775d52e 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -1,6 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { isEqual } from 'lodash-es'; -import { memo, useCallback } from 'react'; +import { memo, useCallback, useEffect } from 'react'; import { useTranslation } from 'react-i18next'; import { RootState } from 'app/store/store'; @@ -9,7 +9,11 @@ import IAIMantineSelect, { IAISelectDataType, } from 'common/components/IAIMantineSelect'; import { generationSelector } from 'features/parameters/store/generationSelectors'; -import { modelSelected } from 'features/parameters/store/generationSlice'; +import { + modelSelected, + setCurrentModelType, +} from 'features/parameters/store/generationSlice'; + import { selectAllSD1Models, selectByIdSD1Models, @@ -55,12 +59,28 @@ export const modelSelector = createSelector( } ); +export type ModelLoaderTypes = 'sd1_model_loader' | 'sd2_model_loader'; + +const MODEL_LOADER_MAP = { + 'sd-1': 'sd1_model_loader', + 'sd-2': 'sd2_model_loader', +}; + const ModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); const { selectedModel, sd1ModelData, sd2ModelData } = useAppSelector(modelSelector); + useEffect(() => { + if (selectedModel) + dispatch( + setCurrentModelType( + MODEL_LOADER_MAP[selectedModel?.base_model] as ModelLoaderTypes + ) + ); + }, [dispatch, selectedModel]); + const handleChangeModel = useCallback( (v: string | null) => { if (!v) {