feat(ui): initial implementation of model loading

- Update model listing code to use `rtk-query`
- Update all graph generation to use new `pipeline_model_loader` node
This commit is contained in:
psychedelicious 2023-06-22 17:48:57 +10:00
parent 2a178f5a25
commit 339e7ce213
26 changed files with 281 additions and 386 deletions

View File

@ -24,6 +24,7 @@ import Toaster from './Toaster';
import DeleteImageModal from 'features/gallery/components/DeleteImageModal'; import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import { useListModelsQuery } from 'services/apiSlice';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};
@ -46,6 +47,18 @@ const App = ({
const isApplicationReady = useIsApplicationReady(); const isApplicationReady = useIsApplicationReady();
const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const { data: controlnetModels } = useListModelsQuery({
model_type: 'controlnet',
});
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
const { data: embeddingModels } = useListModelsQuery({
model_type: 'embedding',
});
const [loadingOverridden, setLoadingOverridden] = useState(false); const [loadingOverridden, setLoadingOverridden] = useState(false);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist'; import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist'; import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist'; import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist'; import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist'; import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
import { omit } from 'lodash-es'; import { omit } from 'lodash-es';
@ -18,8 +17,6 @@ const serializationDenylist: {
gallery: galleryPersistDenylist, gallery: galleryPersistDenylist,
generation: generationPersistDenylist, generation: generationPersistDenylist,
lightbox: lightboxPersistDenylist, lightbox: lightboxPersistDenylist,
sd1models: modelsPersistDenylist,
sd2models: modelsPersistDenylist,
nodes: nodesPersistDenylist, nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist, postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist, system: systemPersistDenylist,

View File

@ -7,8 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice'; import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
import { initialConfigState } from 'features/system/store/configSlice'; import { initialConfigState } from 'features/system/store/configSlice';
import { sd1InitialPipelineModelsState } from 'features/system/store/models/sd1PipelineModelSlice';
import { sd2InitialPipelineModelsState } from 'features/system/store/models/sd2PipelineModelSlice';
import { initialSystemState } from 'features/system/store/systemSlice'; import { initialSystemState } from 'features/system/store/systemSlice';
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice'; import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
import { initialUIState } from 'features/ui/store/uiSlice'; import { initialUIState } from 'features/ui/store/uiSlice';
@ -22,8 +20,6 @@ const initialStates: {
gallery: initialGalleryState, gallery: initialGalleryState,
generation: initialGenerationState, generation: initialGenerationState,
lightbox: initialLightboxState, lightbox: initialLightboxState,
sd1PipelineModels: sd1InitialPipelineModelsState,
sd2PipelineModels: sd2InitialPipelineModelsState,
nodes: initialNodesState, nodes: initialNodesState,
postprocessing: initialPostprocessingState, postprocessing: initialPostprocessingState,
system: initialSystemState, system: initialSystemState,

View File

@ -1,7 +1,6 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions'; import { appSocketConnected, socketConnected } from 'services/events/actions';
import { receivedPageOfImages } from 'services/thunks/image'; import { receivedPageOfImages } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema'; import { receivedOpenAPISchema } from 'services/thunks/schema';
import { startAppListening } from '../..'; import { startAppListening } from '../..';
@ -15,8 +14,7 @@ export const addSocketConnectedEventListener = () => {
moduleLog.debug({ timestamp }, 'Connected'); moduleLog.debug({ timestamp }, 'Connected');
const { sd1pipelinemodels, sd2pipelinemodels, nodes, config, images } = const { nodes, config, images } = getState();
getState();
const { disabledTabs } = config; const { disabledTabs } = config;
@ -29,14 +27,6 @@ export const addSocketConnectedEventListener = () => {
); );
} }
if (!sd1pipelinemodels.ids.length) {
dispatch(receivedModels({ baseModel: 'sd-1', modelType: 'pipeline' }));
}
if (!sd2pipelinemodels.ids.length) {
dispatch(receivedModels({ baseModel: 'sd-2', modelType: 'pipeline' }));
}
if (!nodes.schema && !disabledTabs.includes('nodes')) { if (!nodes.schema && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema()); dispatch(receivedOpenAPISchema());
} }

View File

@ -28,11 +28,6 @@ import { listenerMiddleware } from './middleware/listenerMiddleware';
import { actionSanitizer } from './middleware/devtools/actionSanitizer'; import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist'; import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer'; import { stateSanitizer } from './middleware/devtools/stateSanitizer';
// Model Reducers
import sd1PipelineModelReducer from 'features/system/store/models/sd1PipelineModelSlice';
import sd2PipelineModelReducer from 'features/system/store/models/sd2PipelineModelSlice';
import { LOCALSTORAGE_PREFIX } from './constants'; import { LOCALSTORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize'; import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize'; import { unserialize } from './enhancers/reduxRemember/unserialize';
@ -43,8 +38,6 @@ const allReducers = {
gallery: galleryReducer, gallery: galleryReducer,
generation: generationReducer, generation: generationReducer,
lightbox: lightboxReducer, lightbox: lightboxReducer,
sd1pipelinemodels: sd1PipelineModelReducer,
sd2pipelinemodels: sd2PipelineModelReducer,
nodes: nodesReducer, nodes: nodesReducer,
postprocessing: postprocessingReducer, postprocessing: postprocessingReducer,
system: systemReducer, system: systemReducer,
@ -54,8 +47,8 @@ const allReducers = {
images: imagesReducer, images: imagesReducer,
controlNet: controlNetReducer, controlNet: controlNetReducer,
boards: boardsReducer, boards: boardsReducer,
[api.reducerPath]: api.reducer,
// session: sessionReducer, // session: sessionReducer,
[api.reducerPath]: api.reducer,
}; };
const rootReducer = combineReducers(allReducers); const rootReducer = combineReducers(allReducers);

View File

@ -1,14 +1,18 @@
import { NativeSelect } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
ModelInputFieldTemplate, ModelInputFieldTemplate,
ModelInputFieldValue, ModelInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { modelSelector } from 'features/system/store/modelSelectors'; import { memo, useCallback, useEffect, useMemo } from 'react';
import { ChangeEvent, memo } from 'react';
import { FieldComponentProps } from './types'; import { FieldComponentProps } from './types';
import { forEach, isString } from 'lodash-es';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { useTranslation } from 'react-i18next';
import { useListModelsQuery } from 'services/apiSlice';
const ModelInputFieldComponent = ( const ModelInputFieldComponent = (
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate> props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
@ -16,26 +20,82 @@ const ModelInputFieldComponent = (
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation();
const { sd1PipelineModelDropDownData, sd2PipelineModelDropdownData } = const { data: pipelineModels } = useListModelsQuery({
useAppSelector(modelSelector); model_type: 'pipeline',
});
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => { const data = useMemo(() => {
dispatch( if (!pipelineModels) {
fieldValueChanged({ return [];
nodeId, }
fieldName: field.name,
value: e.target.value, const data: SelectItem[] = [];
})
); forEach(pipelineModels.entities, (model, id) => {
}; if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: BASE_MODEL_NAME_MAP[model.base_model],
});
});
return data;
}, [pipelineModels]);
const selectedModel = useMemo(
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
[pipelineModels?.entities, pipelineModels?.ids, field.value]
);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: v,
})
);
},
[dispatch, field.name, nodeId]
);
useEffect(() => {
if (field.value && pipelineModels?.ids.includes(field.value)) {
return;
}
const firstModel = pipelineModels?.ids[0];
if (!isString(firstModel)) {
return;
}
handleValueChanged(firstModel);
}, [field.value, handleValueChanged, pipelineModels?.ids]);
return ( return (
<NativeSelect <IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model &&
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
}
value={field.value}
placeholder="Pick one"
data={data}
onChange={handleValueChanged} onChange={handleValueChanged}
value={field.value || sd1PipelineModelDropDownData[0].value} />
data={sd1PipelineModelDropDownData.concat(sd2PipelineModelDropdownData)}
></NativeSelect>
); );
}; };

View File

@ -101,21 +101,6 @@ const nodesSlice = createSlice({
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => { builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
state.schema = action.payload; state.schema = action.payload;
}); });
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
state.nodes.forEach((node) => {
forEach(node.data.inputs, (input) => {
if (input.type === 'image') {
if (input.value?.image_name === image_name) {
input.value.image_url = image_url;
input.value.thumbnail_url = thumbnail_url;
}
}
});
});
});
}, },
}); });

View File

@ -23,6 +23,7 @@ import {
} from './constants'; } from './constants';
import { set } from 'lodash-es'; import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -36,7 +37,7 @@ export const buildCanvasImageToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -49,6 +50,8 @@ export const buildCanvasImageToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node * full graph here as a template. Then use the parameters from app state and set friendlier node
@ -85,9 +88,9 @@ export const buildCanvasImageToImageGraph = (
id: NOISE, id: NOISE,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',

View File

@ -17,6 +17,7 @@ import {
INPAINT_GRAPH, INPAINT_GRAPH,
INPAINT, INPAINT,
} from './constants'; } from './constants';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -31,7 +32,7 @@ export const buildCanvasInpaintGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -54,6 +55,8 @@ export const buildCanvasInpaintGraph = (
// We may need to set the inpaint width and height to scale the image // We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas; const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const model = modelIdToPipelineModelField(modelId);
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: INPAINT_GRAPH, id: INPAINT_GRAPH,
nodes: { nodes: {
@ -99,9 +102,9 @@ export const buildCanvasInpaintGraph = (
prompt: negativePrompt, prompt: negativePrompt,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[RANGE_OF_SIZE]: { [RANGE_OF_SIZE]: {
type: 'range_of_size', type: 'range_of_size',

View File

@ -14,6 +14,7 @@ import {
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
/** /**
* Builds the Canvas tab's Text to Image graph. * Builds the Canvas tab's Text to Image graph.
@ -24,7 +25,7 @@ export const buildCanvasTextToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -36,6 +37,8 @@ export const buildCanvasTextToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node * full graph here as a template. Then use the parameters from app state and set friendlier node
@ -80,9 +83,9 @@ export const buildCanvasTextToImageGraph = (
steps, steps,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',

View File

@ -22,6 +22,7 @@ import {
} from './constants'; } from './constants';
import { set } from 'lodash-es'; import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -34,7 +35,7 @@ export const buildLinearImageToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -62,6 +63,8 @@ export const buildLinearImageToImageGraph = (
throw new Error('No initial image found in state'); throw new Error('No initial image found in state');
} }
const model = modelIdToPipelineModelField(modelId);
// copy-pasted graph from node editor, filled in with state values & friendly node ids // copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: IMAGE_TO_IMAGE_GRAPH, id: IMAGE_TO_IMAGE_GRAPH,
@ -89,9 +92,9 @@ export const buildLinearImageToImageGraph = (
id: NOISE, id: NOISE,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',

View File

@ -1,6 +1,10 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api'; import {
BaseModelType,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api';
import { import {
ITERATE, ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
@ -14,6 +18,7 @@ import {
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
type TextToImageGraphOverrides = { type TextToImageGraphOverrides = {
width: number; width: number;
@ -27,7 +32,7 @@ export const buildLinearTextToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -38,6 +43,8 @@ export const buildLinearTextToImageGraph = (
shouldRandomizeSeed, shouldRandomizeSeed,
} = state.generation; } = state.generation;
const model = modelIdToPipelineModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node * full graph here as a template. Then use the parameters from app state and set friendlier node
@ -82,9 +89,9 @@ export const buildLinearTextToImageGraph = (
steps, steps,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',

View File

@ -1,9 +1,10 @@
import { Graph } from 'services/api'; import { Graph } from 'services/api';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es'; import { cloneDeep, omit, reduce } from 'lodash-es';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types'; import { InputFieldValue } from 'features/nodes/types/types';
import { AnyInvocation } from 'services/events/types'; import { AnyInvocation } from 'services/events/types';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
/** /**
* We need to do special handling for some fields * We need to do special handling for some fields
@ -24,6 +25,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
} }
} }
if (field.type === 'model') {
if (field.value) {
return modelIdToPipelineModelField(field.value);
}
}
return field.value; return field.value;
}; };

View File

@ -7,7 +7,7 @@ export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int'; export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size'; export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate'; export const ITERATE = 'iterate';
export const MODEL_LOADER = 'model_loader'; export const MODEL_LOADER = 'pipeline_model_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents'; export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image'; export const RESIZE = 'resize_image';

View File

@ -0,0 +1,18 @@
import { BaseModelType, PipelineModelField } from 'services/api';
/**
* Crudely converts a model id to a pipeline model field
* TODO: Make better
*/
export const modelIdToPipelineModelField = (
modelId: string
): PipelineModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: PipelineModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -1,12 +1,9 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { DEFAULT_SCHEDULER_NAME, Scheduler } from 'app/constants'; import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
import { ModelLoaderTypes } from 'features/system/components/ModelSelect';
import { configChanged } from 'features/system/store/configSlice'; import { configChanged } from 'features/system/store/configSlice';
import { clamp, sortBy } from 'lodash-es'; import { clamp } from 'lodash-es';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { imageUrlsReceived } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model';
import { import {
CfgScaleParam, CfgScaleParam,
HeightParam, HeightParam,
@ -50,7 +47,6 @@ export interface GenerationState {
horizontalSymmetrySteps: number; horizontalSymmetrySteps: number;
verticalSymmetrySteps: number; verticalSymmetrySteps: number;
model: ModelParam; model: ModelParam;
currentModelType: ModelLoaderTypes;
shouldUseSeamless: boolean; shouldUseSeamless: boolean;
seamlessXAxis: boolean; seamlessXAxis: boolean;
seamlessYAxis: boolean; seamlessYAxis: boolean;
@ -85,7 +81,6 @@ export const initialGenerationState: GenerationState = {
horizontalSymmetrySteps: 0, horizontalSymmetrySteps: 0,
verticalSymmetrySteps: 0, verticalSymmetrySteps: 0,
model: '', model: '',
currentModelType: 'sd1_model_loader',
shouldUseSeamless: false, shouldUseSeamless: false,
seamlessXAxis: true, seamlessXAxis: true,
seamlessYAxis: true, seamlessYAxis: true,
@ -221,33 +216,14 @@ export const generationSlice = createSlice({
modelSelected: (state, action: PayloadAction<string>) => { modelSelected: (state, action: PayloadAction<string>) => {
state.model = action.payload; state.model = action.payload;
}, },
setCurrentModelType: (state, action: PayloadAction<ModelLoaderTypes>) => {
state.currentModelType = action.payload;
},
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(receivedModels.fulfilled, (state, action) => {
if (!state.model) {
const firstModel = sortBy(action.payload, 'name')[0];
state.model = firstModel.name;
}
});
builder.addCase(configChanged, (state, action) => { builder.addCase(configChanged, (state, action) => {
const defaultModel = action.payload.sd?.defaultModel; const defaultModel = action.payload.sd?.defaultModel;
if (defaultModel && !state.model) { if (defaultModel && !state.model) {
state.model = defaultModel; state.model = defaultModel;
} }
}); });
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
// if (state.initialImage?.image_name === image_name) {
// state.initialImage.image_url = image_url;
// state.initialImage.thumbnail_url = thumbnail_url;
// }
// });
}, },
}); });
@ -284,7 +260,6 @@ export const {
setVerticalSymmetrySteps, setVerticalSymmetrySteps,
initialImageChanged, initialImageChanged,
modelSelected, modelSelected,
setCurrentModelType,
setShouldUseNoiseSettings, setShouldUseNoiseSettings,
setSeamless, setSeamless,
setSeamlessXAxis, setSeamlessXAxis,

View File

@ -154,3 +154,17 @@ export type StrengthParam = z.infer<typeof zStrength>;
*/ */
export const isValidStrength = (val: unknown): val is StrengthParam => export const isValidStrength = (val: unknown): val is StrengthParam =>
zStrength.safeParse(val).success; zStrength.safeParse(val).success;
// /**
// * Zod schema for BaseModelType
// */
// export const zBaseModelType = z.enum(['sd-1', 'sd-2']);
// /**
// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI.
// */
// export type BaseModelType = z.infer<typeof zBaseModelType>;
// /**
// * Validates/type-guards a value as a base model type
// */
// export const isValidBaseModelType = (val: unknown): val is BaseModelType =>
// zBaseModelType.safeParse(val).success;

View File

@ -1,39 +1,58 @@
import { memo, useCallback, useEffect } from 'react'; import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { import { modelSelected } from 'features/parameters/store/generationSlice';
modelSelected,
setCurrentModelType,
} from 'features/parameters/store/generationSlice';
import { modelSelector } from '../store/modelSelectors'; import { forEach, isString } from 'lodash-es';
import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store';
import { useListModelsQuery } from 'services/apiSlice';
export type ModelLoaderTypes = 'sd1_model_loader' | 'sd2_model_loader'; export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x',
const MODEL_LOADER_MAP = { 'sd-2': 'Stable Diffusion 2.x',
'sd-1': 'sd1_model_loader',
'sd-2': 'sd2_model_loader',
}; };
const ModelSelect = () => { const ModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const {
selectedModel,
sd1PipelineModelDropDownData,
sd2PipelineModelDropdownData,
} = useAppSelector(modelSelector);
useEffect(() => { const selectedModelId = useAppSelector(
if (selectedModel) (state: RootState) => state.generation.model
dispatch( );
setCurrentModelType(
MODEL_LOADER_MAP[selectedModel?.base_model] as ModelLoaderTypes const { data: pipelineModels } = useListModelsQuery({
) model_type: 'pipeline',
); });
}, [dispatch, selectedModel]);
const data = useMemo(() => {
if (!pipelineModels) {
return [];
}
const data: SelectItem[] = [];
forEach(pipelineModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [pipelineModels]);
const selectedModel = useMemo(
() => pipelineModels?.entities[selectedModelId],
[pipelineModels?.entities, selectedModelId]
);
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
(v: string | null) => { (v: string | null) => {
@ -45,13 +64,27 @@ const ModelSelect = () => {
[dispatch] [dispatch]
); );
useEffect(() => {
if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
return;
}
const firstModel = pipelineModels?.ids[0];
if (!isString(firstModel)) {
return;
}
handleChangeModel(firstModel);
}, [handleChangeModel, pipelineModels?.ids, selectedModelId]);
return ( return (
<IAIMantineSelect <IAIMantineSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={t('modelManager.model')} label={t('modelManager.model')}
value={selectedModel?.name ?? ''} value={selectedModelId}
placeholder="Pick one" placeholder="Pick one"
data={sd1PipelineModelDropDownData.concat(sd2PipelineModelDropdownData)} data={data}
onChange={handleChangeModel} onChange={handleChangeModel}
/> />
); );

View File

@ -7,13 +7,12 @@ import { systemSelector } from '../store/systemSelectors';
const isApplicationReadySelector = createSelector( const isApplicationReadySelector = createSelector(
[systemSelector, configSelector], [systemSelector, configSelector],
(system, config) => { (system, config) => {
const { wereModelsReceived, wasSchemaParsed } = system; const { wasSchemaParsed } = system;
const { disabledTabs } = config; const { disabledTabs } = config;
return { return {
disabledTabs, disabledTabs,
wereModelsReceived,
wasSchemaParsed, wasSchemaParsed,
}; };
} }
@ -23,21 +22,17 @@ const isApplicationReadySelector = createSelector(
* Checks if the application is ready to be used, i.e. if the initial startup process is finished. * Checks if the application is ready to be used, i.e. if the initial startup process is finished.
*/ */
export const useIsApplicationReady = () => { export const useIsApplicationReady = () => {
const { disabledTabs, wereModelsReceived, wasSchemaParsed } = useAppSelector( const { disabledTabs, wasSchemaParsed } = useAppSelector(
isApplicationReadySelector isApplicationReadySelector
); );
const isApplicationReady = useMemo(() => { const isApplicationReady = useMemo(() => {
if (!wereModelsReceived) {
return false;
}
if (!disabledTabs.includes('nodes') && !wasSchemaParsed) { if (!disabledTabs.includes('nodes') && !wasSchemaParsed) {
return false; return false;
} }
return true; return true;
}, [disabledTabs, wereModelsReceived, wasSchemaParsed]); }, [disabledTabs, wasSchemaParsed]);
return isApplicationReady; return isApplicationReady;
}; };

View File

@ -1,59 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { IAISelectDataType } from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { isEqual } from 'lodash-es';
import {
selectAllSD1PipelineModels,
selectByIdSD1PipelineModels,
} from './models/sd1PipelineModelSlice';
import {
selectAllSD2PipelineModels,
selectByIdSD2PipelineModels,
} from './models/sd2PipelineModelSlice';
export const modelSelector = createSelector(
[(state: RootState) => state, generationSelector],
(state, generation) => {
let selectedModel = selectByIdSD1PipelineModels(state, generation.model);
if (selectedModel === undefined)
selectedModel = selectByIdSD2PipelineModels(state, generation.model);
const sd1PipelineModels = selectAllSD1PipelineModels(state);
const sd2PipelineModels = selectAllSD2PipelineModels(state);
const allPipelineModels = sd1PipelineModels.concat(sd2PipelineModels);
const sd1PipelineModelDropDownData = selectAllSD1PipelineModels(state)
.map<IAISelectDataType>((m) => ({
value: m.name,
label: m.name,
group: '1.x Models',
}))
.sort((a, b) => a.label.localeCompare(b.label));
const sd2PipelineModelDropdownData = selectAllSD2PipelineModels(state)
.map<IAISelectDataType>((m) => ({
value: m.name,
label: m.name,
group: '2.x Models',
}))
.sort((a, b) => a.label.localeCompare(b.label));
return {
selectedModel,
allPipelineModels,
sd1PipelineModels,
sd2PipelineModels,
sd1PipelineModelDropDownData,
sd2PipelineModelDropdownData,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);

View File

@ -1,56 +0,0 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
StableDiffusion1ModelCheckpointConfig,
StableDiffusion1ModelDiffusersConfig,
} from 'services/api';
import { receivedModels } from 'services/thunks/model';
export type SD1PipelineModel = (
| StableDiffusion1ModelCheckpointConfig
| StableDiffusion1ModelDiffusersConfig
) & {
name: string;
};
export const sd1PipelineModelsAdapter = createEntityAdapter<SD1PipelineModel>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const sd1InitialPipelineModelsState =
sd1PipelineModelsAdapter.getInitialState();
export type SD1PipelineModelState = typeof sd1InitialPipelineModelsState;
export const sd1PipelineModelsSlice = createSlice({
name: 'sd1PipelineModels',
initialState: sd1InitialPipelineModelsState,
reducers: {
modelAdded: sd1PipelineModelsAdapter.upsertOne,
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(receivedModels.fulfilled, (state, action) => {
if (action.meta.arg.baseModel !== 'sd-1') return;
sd1PipelineModelsAdapter.setAll(state, action.payload);
});
},
});
export const {
selectAll: selectAllSD1PipelineModels,
selectById: selectByIdSD1PipelineModels,
selectEntities: selectEntitiesSD1PipelineModels,
selectIds: selectIdsSD1PipelineModels,
selectTotal: selectTotalSD1PipelineModels,
} = sd1PipelineModelsAdapter.getSelectors<RootState>(
(state) => state.sd1pipelinemodels
);
export const { modelAdded } = sd1PipelineModelsSlice.actions;
export default sd1PipelineModelsSlice.reducer;

View File

@ -1,56 +0,0 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
StableDiffusion2ModelCheckpointConfig,
StableDiffusion2ModelDiffusersConfig,
} from 'services/api';
import { receivedModels } from 'services/thunks/model';
export type SD2PipelineModel = (
| StableDiffusion2ModelCheckpointConfig
| StableDiffusion2ModelDiffusersConfig
) & {
name: string;
};
export const sd2PipelineModelsAdapater = createEntityAdapter<SD2PipelineModel>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const sd2InitialPipelineModelsState =
sd2PipelineModelsAdapater.getInitialState();
export type SD2PipelineModelState = typeof sd2InitialPipelineModelsState;
export const sd2PipelineModelsSlice = createSlice({
name: 'sd2PipelineModels',
initialState: sd2InitialPipelineModelsState,
reducers: {
modelAdded: sd2PipelineModelsAdapater.upsertOne,
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(receivedModels.fulfilled, (state, action) => {
if (action.meta.arg.baseModel !== 'sd-2') return;
sd2PipelineModelsAdapater.setAll(state, action.payload);
});
},
});
export const {
selectAll: selectAllSD2PipelineModels,
selectById: selectByIdSD2PipelineModels,
selectEntities: selectEntitiesSD2PipelineModels,
selectIds: selectIdsSD2PipelineModels,
selectTotal: selectTotalSD2PipelineModels,
} = sd2PipelineModelsAdapater.getSelectors<RootState>(
(state) => state.sd2pipelinemodels
);
export const { modelAdded } = sd2PipelineModelsSlice.actions;
export default sd2PipelineModelsSlice.reducer;

View File

@ -1,9 +0,0 @@
import { SD1PipelineModelState } from './models/sd1PipelineModelSlice';
import { SD2PipelineModelState } from './models/sd2PipelineModelSlice';
/**
* Models slice persist denylist
*/
export const modelsPersistDenylist:
| (keyof SD1PipelineModelState)[]
| (keyof SD2PipelineModelState)[] = ['entities', 'ids'];

View File

@ -20,7 +20,6 @@ import {
} from 'services/events/actions'; } from 'services/events/actions';
import { ProgressImage } from 'services/events/types'; import { ProgressImage } from 'services/events/types';
import { imageUploaded } from 'services/thunks/image'; import { imageUploaded } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model';
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session'; import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
import { makeToast } from '../../../app/components/Toaster'; import { makeToast } from '../../../app/components/Toaster';
import { LANGUAGES } from '../components/LanguagePicker'; import { LANGUAGES } from '../components/LanguagePicker';
@ -377,13 +376,6 @@ export const systemSlice = createSlice({
); );
}); });
/**
* Received available models from the backend
*/
builder.addCase(receivedModels.fulfilled, (state) => {
state.wereModelsReceived = true;
});
/** /**
* OpenAPI schema was parsed * OpenAPI schema was parsed
*/ */

View File

@ -13,23 +13,68 @@ import {
TagTypesFrom, TagTypesFrom,
TagTypesFromApi, TagTypesFromApi,
} from '@reduxjs/toolkit/dist/query/endpointDefinitions'; } from '@reduxjs/toolkit/dist/query/endpointDefinitions';
import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
import { BaseModelType } from './api/models/BaseModelType';
import { ModelType } from './api/models/ModelType';
import { ModelsList } from './api/models/ModelsList';
import { keyBy } from 'lodash-es';
type ListBoardsArg = { offset: number; limit: number }; type ListBoardsArg = { offset: number; limit: number };
type UpdateBoardArg = { board_id: string; changes: BoardChanges }; type UpdateBoardArg = { board_id: string; changes: BoardChanges };
type AddImageToBoardArg = { board_id: string; image_name: string }; type AddImageToBoardArg = { board_id: string; image_name: string };
type RemoveImageFromBoardArg = { board_id: string; image_name: string }; type RemoveImageFromBoardArg = { board_id: string; image_name: string };
type ListBoardImagesArg = { board_id: string; offset: number; limit: number }; type ListBoardImagesArg = { board_id: string; offset: number; limit: number };
type ListModelsArg = { base_model?: BaseModelType; model_type?: ModelType };
const tagTypes = ['Board', 'Image']; type ModelConfig = ModelsList['models'][number];
const tagTypes = ['Board', 'Image', 'Model'];
type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>; type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>;
const LIST = 'LIST'; const LIST = 'LIST';
const modelsAdapter = createEntityAdapter<ModelConfig>({
selectId: (model) => getModelId(model),
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
const getModelId = ({ base_model, type, name }: ModelConfig) =>
`${base_model}/${type}/${name}`;
export const api = createApi({ export const api = createApi({
baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }), baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }),
reducerPath: 'api', reducerPath: 'api',
tagTypes, tagTypes,
endpoints: (build) => ({ endpoints: (build) => ({
/**
* Models Queries
*/
listModels: build.query<EntityState<ModelConfig>, ListModelsArg>({
query: (arg) => ({ url: 'models/', params: arg }),
providesTags: (result, error, arg) => {
// any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST }];
if (result) {
// and individual tags for each board
tags.push(
...result.ids.map((id) => ({
type: 'Model' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: ModelsList, meta, arg) => {
return modelsAdapter.addMany(
modelsAdapter.getInitialState(),
keyBy(response.models, getModelId)
);
},
}),
/** /**
* Boards Queries * Boards Queries
*/ */
@ -174,4 +219,5 @@ export const {
useRemoveImageFromBoardMutation, useRemoveImageFromBoardMutation,
useListBoardImagesQuery, useListBoardImagesQuery,
useGetImageDTOQuery, useGetImageDTOQuery,
useListModelsQuery,
} = api; } = api;

View File

@ -1,58 +0,0 @@
import { log } from 'app/logging/useLogger';
import { createAppAsyncThunk } from 'app/store/storeUtils';
import { SD1PipelineModel } from 'features/system/store/models/sd1PipelineModelSlice';
import { SD2PipelineModel } from 'features/system/store/models/sd2PipelineModelSlice';
import { reduce, size } from 'lodash-es';
import { BaseModelType, ModelType, ModelsService } from 'services/api';
const models = log.child({ namespace: 'model' });
export const IMAGES_PER_PAGE = 20;
type receivedModelsArg = {
baseModel: BaseModelType | undefined;
modelType: ModelType | undefined;
};
export const receivedModels = createAppAsyncThunk(
'models/receivedModels',
async (arg: receivedModelsArg) => {
const response = await ModelsService.listModels(arg);
let deserializedModels = {};
if (arg.baseModel === undefined) return response.models;
if (arg.modelType === undefined) return response.models;
if (arg.baseModel === 'sd-1') {
deserializedModels = reduce(
response.models[arg.baseModel][arg.modelType],
(modelsAccumulator, model, modelName) => {
modelsAccumulator[modelName] = { ...model, name: modelName };
return modelsAccumulator;
},
{} as Record<string, SD1PipelineModel>
);
}
if (arg.baseModel === 'sd-2') {
deserializedModels = reduce(
response.models[arg.baseModel][arg.modelType],
(modelsAccumulator, model, modelName) => {
modelsAccumulator[modelName] = { ...model, name: modelName };
return modelsAccumulator;
},
{} as Record<string, SD2PipelineModel>
);
}
models.info(
{ response },
`Received ${size(response.models[arg.baseModel][arg.modelType])} ${[
arg.baseModel,
]} models`
);
return deserializedModels;
}
);