feat: Enable 2.x Model Generation in Linear UI

This commit is contained in:
blessedcoolant 2023-06-18 08:27:13 +12:00
parent bf0577c882
commit 61c426f502
4 changed files with 33 additions and 4 deletions

View File

@ -40,6 +40,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model, model,
currentModelType,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -66,7 +67,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
// Create the model loader node // Create the model loader node
const modelLoaderNode: SD1ModelLoaderInvocation | SD2ModelLoaderInvocation = { const modelLoaderNode: SD1ModelLoaderInvocation | SD2ModelLoaderInvocation = {
id: MODEL_LOADER, id: MODEL_LOADER,
type: 'sd1_model_loader', type: currentModelType,
model_name: model, model_name: model,
}; };

View File

@ -30,6 +30,7 @@ const ITERATE = 'iterate';
export const buildTextToImageGraph = (state: RootState): Graph => { export const buildTextToImageGraph = (state: RootState): Graph => {
const { const {
model, model,
currentModelType,
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
cfgScale: cfg_scale, cfgScale: cfg_scale,
@ -50,7 +51,7 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
// Create the model loader node // Create the model loader node
const modelLoaderNode: SD1ModelLoaderInvocation | SD2ModelLoaderInvocation = { const modelLoaderNode: SD1ModelLoaderInvocation | SD2ModelLoaderInvocation = {
id: MODEL_LOADER, id: MODEL_LOADER,
type: 'sd1_model_loader', type: currentModelType,
model_name: model, model_name: model,
}; };

View File

@ -1,6 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { Scheduler } from 'app/constants'; import { Scheduler } from 'app/constants';
import { ModelLoaderTypes } from 'features/system/components/ModelSelect';
import { configChanged } from 'features/system/store/configSlice'; import { configChanged } from 'features/system/store/configSlice';
import { clamp, sortBy } from 'lodash-es'; import { clamp, sortBy } from 'lodash-es';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
@ -49,6 +50,7 @@ export interface GenerationState {
horizontalSymmetrySteps: number; horizontalSymmetrySteps: number;
verticalSymmetrySteps: number; verticalSymmetrySteps: number;
model: ModelParam; model: ModelParam;
currentModelType: ModelLoaderTypes;
shouldUseSeamless: boolean; shouldUseSeamless: boolean;
seamlessXAxis: boolean; seamlessXAxis: boolean;
seamlessYAxis: boolean; seamlessYAxis: boolean;
@ -83,6 +85,7 @@ export const initialGenerationState: GenerationState = {
horizontalSymmetrySteps: 0, horizontalSymmetrySteps: 0,
verticalSymmetrySteps: 0, verticalSymmetrySteps: 0,
model: '', model: '',
currentModelType: 'sd1_model_loader',
shouldUseSeamless: false, shouldUseSeamless: false,
seamlessXAxis: true, seamlessXAxis: true,
seamlessYAxis: true, seamlessYAxis: true,
@ -217,6 +220,9 @@ export const generationSlice = createSlice({
modelSelected: (state, action: PayloadAction<string>) => { modelSelected: (state, action: PayloadAction<string>) => {
state.model = action.payload; state.model = action.payload;
}, },
setCurrentModelType: (state, action: PayloadAction<ModelLoaderTypes>) => {
state.currentModelType = action.payload;
},
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(getModels.fulfilled, (state, action) => { builder.addCase(getModels.fulfilled, (state, action) => {
@ -277,6 +283,7 @@ export const {
setVerticalSymmetrySteps, setVerticalSymmetrySteps,
initialImageChanged, initialImageChanged,
modelSelected, modelSelected,
setCurrentModelType,
setShouldUseNoiseSettings, setShouldUseNoiseSettings,
setSeamless, setSeamless,
setSeamlessXAxis, setSeamlessXAxis,

View File

@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
@ -9,7 +9,11 @@ import IAIMantineSelect, {
IAISelectDataType, IAISelectDataType,
} from 'common/components/IAIMantineSelect'; } from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { modelSelected } from 'features/parameters/store/generationSlice'; import {
modelSelected,
setCurrentModelType,
} from 'features/parameters/store/generationSlice';
import { import {
selectAllSD1Models, selectAllSD1Models,
selectByIdSD1Models, selectByIdSD1Models,
@ -55,12 +59,28 @@ export const modelSelector = createSelector(
} }
); );
export type ModelLoaderTypes = 'sd1_model_loader' | 'sd2_model_loader';
const MODEL_LOADER_MAP = {
'sd-1': 'sd1_model_loader',
'sd-2': 'sd2_model_loader',
};
const ModelSelect = () => { const ModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { selectedModel, sd1ModelData, sd2ModelData } = const { selectedModel, sd1ModelData, sd2ModelData } =
useAppSelector(modelSelector); useAppSelector(modelSelector);
useEffect(() => {
if (selectedModel)
dispatch(
setCurrentModelType(
MODEL_LOADER_MAP[selectedModel?.base_model] as ModelLoaderTypes
)
);
}, [dispatch, selectedModel]);
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
(v: string | null) => { (v: string | null) => {
if (!v) { if (!v) {