mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): initial implementation of model loading
- Update model listing code to use `rtk-query` - Update all graph generation to use new `pipeline_model_loader` node
This commit is contained in:
parent
2a178f5a25
commit
339e7ce213
@ -24,6 +24,7 @@ import Toaster from './Toaster';
|
|||||||
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
|
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
|
||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||||
|
import { useListModelsQuery } from 'services/apiSlice';
|
||||||
|
|
||||||
const DEFAULT_CONFIG = {};
|
const DEFAULT_CONFIG = {};
|
||||||
|
|
||||||
@ -46,6 +47,18 @@ const App = ({
|
|||||||
|
|
||||||
const isApplicationReady = useIsApplicationReady();
|
const isApplicationReady = useIsApplicationReady();
|
||||||
|
|
||||||
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
|
model_type: 'pipeline',
|
||||||
|
});
|
||||||
|
const { data: controlnetModels } = useListModelsQuery({
|
||||||
|
model_type: 'controlnet',
|
||||||
|
});
|
||||||
|
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
|
||||||
|
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
|
||||||
|
const { data: embeddingModels } = useListModelsQuery({
|
||||||
|
model_type: 'embedding',
|
||||||
|
});
|
||||||
|
|
||||||
const [loadingOverridden, setLoadingOverridden] = useState(false);
|
const [loadingOverridden, setLoadingOverridden] = useState(false);
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist
|
|||||||
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||||
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||||
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
||||||
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
|
|
||||||
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
|
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
|
||||||
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
|
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
|
||||||
import { omit } from 'lodash-es';
|
import { omit } from 'lodash-es';
|
||||||
@ -18,8 +17,6 @@ const serializationDenylist: {
|
|||||||
gallery: galleryPersistDenylist,
|
gallery: galleryPersistDenylist,
|
||||||
generation: generationPersistDenylist,
|
generation: generationPersistDenylist,
|
||||||
lightbox: lightboxPersistDenylist,
|
lightbox: lightboxPersistDenylist,
|
||||||
sd1models: modelsPersistDenylist,
|
|
||||||
sd2models: modelsPersistDenylist,
|
|
||||||
nodes: nodesPersistDenylist,
|
nodes: nodesPersistDenylist,
|
||||||
postprocessing: postprocessingPersistDenylist,
|
postprocessing: postprocessingPersistDenylist,
|
||||||
system: systemPersistDenylist,
|
system: systemPersistDenylist,
|
||||||
|
@ -7,8 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
|||||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||||
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
||||||
import { initialConfigState } from 'features/system/store/configSlice';
|
import { initialConfigState } from 'features/system/store/configSlice';
|
||||||
import { sd1InitialPipelineModelsState } from 'features/system/store/models/sd1PipelineModelSlice';
|
|
||||||
import { sd2InitialPipelineModelsState } from 'features/system/store/models/sd2PipelineModelSlice';
|
|
||||||
import { initialSystemState } from 'features/system/store/systemSlice';
|
import { initialSystemState } from 'features/system/store/systemSlice';
|
||||||
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
||||||
import { initialUIState } from 'features/ui/store/uiSlice';
|
import { initialUIState } from 'features/ui/store/uiSlice';
|
||||||
@ -22,8 +20,6 @@ const initialStates: {
|
|||||||
gallery: initialGalleryState,
|
gallery: initialGalleryState,
|
||||||
generation: initialGenerationState,
|
generation: initialGenerationState,
|
||||||
lightbox: initialLightboxState,
|
lightbox: initialLightboxState,
|
||||||
sd1PipelineModels: sd1InitialPipelineModelsState,
|
|
||||||
sd2PipelineModels: sd2InitialPipelineModelsState,
|
|
||||||
nodes: initialNodesState,
|
nodes: initialNodesState,
|
||||||
postprocessing: initialPostprocessingState,
|
postprocessing: initialPostprocessingState,
|
||||||
system: initialSystemState,
|
system: initialSystemState,
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||||
import { receivedPageOfImages } from 'services/thunks/image';
|
import { receivedPageOfImages } from 'services/thunks/image';
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||||
import { startAppListening } from '../..';
|
import { startAppListening } from '../..';
|
||||||
|
|
||||||
@ -15,8 +14,7 @@ export const addSocketConnectedEventListener = () => {
|
|||||||
|
|
||||||
moduleLog.debug({ timestamp }, 'Connected');
|
moduleLog.debug({ timestamp }, 'Connected');
|
||||||
|
|
||||||
const { sd1pipelinemodels, sd2pipelinemodels, nodes, config, images } =
|
const { nodes, config, images } = getState();
|
||||||
getState();
|
|
||||||
|
|
||||||
const { disabledTabs } = config;
|
const { disabledTabs } = config;
|
||||||
|
|
||||||
@ -29,14 +27,6 @@ export const addSocketConnectedEventListener = () => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!sd1pipelinemodels.ids.length) {
|
|
||||||
dispatch(receivedModels({ baseModel: 'sd-1', modelType: 'pipeline' }));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!sd2pipelinemodels.ids.length) {
|
|
||||||
dispatch(receivedModels({ baseModel: 'sd-2', modelType: 'pipeline' }));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
||||||
dispatch(receivedOpenAPISchema());
|
dispatch(receivedOpenAPISchema());
|
||||||
}
|
}
|
||||||
|
@ -28,11 +28,6 @@ import { listenerMiddleware } from './middleware/listenerMiddleware';
|
|||||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||||
|
|
||||||
// Model Reducers
|
|
||||||
import sd1PipelineModelReducer from 'features/system/store/models/sd1PipelineModelSlice';
|
|
||||||
import sd2PipelineModelReducer from 'features/system/store/models/sd2PipelineModelSlice';
|
|
||||||
|
|
||||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||||
import { serialize } from './enhancers/reduxRemember/serialize';
|
import { serialize } from './enhancers/reduxRemember/serialize';
|
||||||
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
||||||
@ -43,8 +38,6 @@ const allReducers = {
|
|||||||
gallery: galleryReducer,
|
gallery: galleryReducer,
|
||||||
generation: generationReducer,
|
generation: generationReducer,
|
||||||
lightbox: lightboxReducer,
|
lightbox: lightboxReducer,
|
||||||
sd1pipelinemodels: sd1PipelineModelReducer,
|
|
||||||
sd2pipelinemodels: sd2PipelineModelReducer,
|
|
||||||
nodes: nodesReducer,
|
nodes: nodesReducer,
|
||||||
postprocessing: postprocessingReducer,
|
postprocessing: postprocessingReducer,
|
||||||
system: systemReducer,
|
system: systemReducer,
|
||||||
@ -54,8 +47,8 @@ const allReducers = {
|
|||||||
images: imagesReducer,
|
images: imagesReducer,
|
||||||
controlNet: controlNetReducer,
|
controlNet: controlNetReducer,
|
||||||
boards: boardsReducer,
|
boards: boardsReducer,
|
||||||
[api.reducerPath]: api.reducer,
|
|
||||||
// session: sessionReducer,
|
// session: sessionReducer,
|
||||||
|
[api.reducerPath]: api.reducer,
|
||||||
};
|
};
|
||||||
|
|
||||||
const rootReducer = combineReducers(allReducers);
|
const rootReducer = combineReducers(allReducers);
|
||||||
|
@ -1,14 +1,18 @@
|
|||||||
import { NativeSelect } from '@mantine/core';
|
import { SelectItem } from '@mantine/core';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import {
|
import {
|
||||||
ModelInputFieldTemplate,
|
ModelInputFieldTemplate,
|
||||||
ModelInputFieldValue,
|
ModelInputFieldValue,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
|
|
||||||
import { modelSelector } from 'features/system/store/modelSelectors';
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { ChangeEvent, memo } from 'react';
|
|
||||||
import { FieldComponentProps } from './types';
|
import { FieldComponentProps } from './types';
|
||||||
|
import { forEach, isString } from 'lodash-es';
|
||||||
|
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useListModelsQuery } from 'services/apiSlice';
|
||||||
|
|
||||||
const ModelInputFieldComponent = (
|
const ModelInputFieldComponent = (
|
||||||
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
||||||
@ -16,26 +20,82 @@ const ModelInputFieldComponent = (
|
|||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { sd1PipelineModelDropDownData, sd2PipelineModelDropdownData } =
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
useAppSelector(modelSelector);
|
model_type: 'pipeline',
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!pipelineModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(pipelineModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.name,
|
||||||
|
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [pipelineModels]);
|
||||||
|
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
|
||||||
|
[pipelineModels?.entities, pipelineModels?.ids, field.value]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleValueChanged = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
|
||||||
dispatch(
|
dispatch(
|
||||||
fieldValueChanged({
|
fieldValueChanged({
|
||||||
nodeId,
|
nodeId,
|
||||||
fieldName: field.name,
|
fieldName: field.name,
|
||||||
value: e.target.value,
|
value: v,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
};
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (field.value && pipelineModels?.ids.includes(field.value)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstModel = pipelineModels?.ids[0];
|
||||||
|
|
||||||
|
if (!isString(firstModel)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
handleValueChanged(firstModel);
|
||||||
|
}, [field.value, handleValueChanged, pipelineModels?.ids]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<NativeSelect
|
<IAIMantineSelect
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
label={
|
||||||
|
selectedModel?.base_model &&
|
||||||
|
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||||
|
}
|
||||||
|
value={field.value}
|
||||||
|
placeholder="Pick one"
|
||||||
|
data={data}
|
||||||
onChange={handleValueChanged}
|
onChange={handleValueChanged}
|
||||||
value={field.value || sd1PipelineModelDropDownData[0].value}
|
/>
|
||||||
data={sd1PipelineModelDropDownData.concat(sd2PipelineModelDropdownData)}
|
|
||||||
></NativeSelect>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -101,21 +101,6 @@ const nodesSlice = createSlice({
|
|||||||
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
|
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
|
||||||
state.schema = action.payload;
|
state.schema = action.payload;
|
||||||
});
|
});
|
||||||
|
|
||||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
|
||||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
|
||||||
|
|
||||||
state.nodes.forEach((node) => {
|
|
||||||
forEach(node.data.inputs, (input) => {
|
|
||||||
if (input.type === 'image') {
|
|
||||||
if (input.value?.image_name === image_name) {
|
|
||||||
input.value.image_url = image_url;
|
|
||||||
input.value.thumbnail_url = thumbnail_url;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ import {
|
|||||||
} from './constants';
|
} from './constants';
|
||||||
import { set } from 'lodash-es';
|
import { set } from 'lodash-es';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -36,7 +37,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -49,6 +50,8 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
// The bounding box determines width and height, not the width and height params
|
// The bounding box determines width and height, not the width and height params
|
||||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
@ -85,9 +88,9 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
id: NOISE,
|
id: NOISE,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
|
@ -17,6 +17,7 @@ import {
|
|||||||
INPAINT_GRAPH,
|
INPAINT_GRAPH,
|
||||||
INPAINT,
|
INPAINT,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -31,7 +32,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -54,6 +55,8 @@ export const buildCanvasInpaintGraph = (
|
|||||||
// We may need to set the inpaint width and height to scale the image
|
// We may need to set the inpaint width and height to scale the image
|
||||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: INPAINT_GRAPH,
|
id: INPAINT_GRAPH,
|
||||||
nodes: {
|
nodes: {
|
||||||
@ -99,9 +102,9 @@ export const buildCanvasInpaintGraph = (
|
|||||||
prompt: negativePrompt,
|
prompt: negativePrompt,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[RANGE_OF_SIZE]: {
|
[RANGE_OF_SIZE]: {
|
||||||
type: 'range_of_size',
|
type: 'range_of_size',
|
||||||
|
@ -14,6 +14,7 @@ import {
|
|||||||
TEXT_TO_LATENTS,
|
TEXT_TO_LATENTS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Text to Image graph.
|
* Builds the Canvas tab's Text to Image graph.
|
||||||
@ -24,7 +25,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -36,6 +37,8 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
// The bounding box determines width and height, not the width and height params
|
// The bounding box determines width and height, not the width and height params
|
||||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
@ -80,9 +83,9 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
steps,
|
steps,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
|
@ -22,6 +22,7 @@ import {
|
|||||||
} from './constants';
|
} from './constants';
|
||||||
import { set } from 'lodash-es';
|
import { set } from 'lodash-es';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -34,7 +35,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -62,6 +63,8 @@ export const buildLinearImageToImageGraph = (
|
|||||||
throw new Error('No initial image found in state');
|
throw new Error('No initial image found in state');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: IMAGE_TO_IMAGE_GRAPH,
|
id: IMAGE_TO_IMAGE_GRAPH,
|
||||||
@ -89,9 +92,9 @@ export const buildLinearImageToImageGraph = (
|
|||||||
id: NOISE,
|
id: NOISE,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api';
|
import {
|
||||||
|
BaseModelType,
|
||||||
|
RandomIntInvocation,
|
||||||
|
RangeOfSizeInvocation,
|
||||||
|
} from 'services/api';
|
||||||
import {
|
import {
|
||||||
ITERATE,
|
ITERATE,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
@ -14,6 +18,7 @@ import {
|
|||||||
TEXT_TO_LATENTS,
|
TEXT_TO_LATENTS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
type TextToImageGraphOverrides = {
|
type TextToImageGraphOverrides = {
|
||||||
width: number;
|
width: number;
|
||||||
@ -27,7 +32,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -38,6 +43,8 @@ export const buildLinearTextToImageGraph = (
|
|||||||
shouldRandomizeSeed,
|
shouldRandomizeSeed,
|
||||||
} = state.generation;
|
} = state.generation;
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
@ -82,9 +89,9 @@ export const buildLinearTextToImageGraph = (
|
|||||||
steps,
|
steps,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import { Graph } from 'services/api';
|
import { Graph } from 'services/api';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es';
|
import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { InputFieldValue } from 'features/nodes/types/types';
|
import { InputFieldValue } from 'features/nodes/types/types';
|
||||||
import { AnyInvocation } from 'services/events/types';
|
import { AnyInvocation } from 'services/events/types';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* We need to do special handling for some fields
|
* We need to do special handling for some fields
|
||||||
@ -24,6 +25,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (field.type === 'model') {
|
||||||
|
if (field.value) {
|
||||||
|
return modelIdToPipelineModelField(field.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return field.value;
|
return field.value;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ export const NOISE = 'noise';
|
|||||||
export const RANDOM_INT = 'rand_int';
|
export const RANDOM_INT = 'rand_int';
|
||||||
export const RANGE_OF_SIZE = 'range_of_size';
|
export const RANGE_OF_SIZE = 'range_of_size';
|
||||||
export const ITERATE = 'iterate';
|
export const ITERATE = 'iterate';
|
||||||
export const MODEL_LOADER = 'model_loader';
|
export const MODEL_LOADER = 'pipeline_model_loader';
|
||||||
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
||||||
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
||||||
export const RESIZE = 'resize_image';
|
export const RESIZE = 'resize_image';
|
||||||
|
@ -0,0 +1,18 @@
|
|||||||
|
import { BaseModelType, PipelineModelField } from 'services/api';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Crudely converts a model id to a pipeline model field
|
||||||
|
* TODO: Make better
|
||||||
|
*/
|
||||||
|
export const modelIdToPipelineModelField = (
|
||||||
|
modelId: string
|
||||||
|
): PipelineModelField => {
|
||||||
|
const [base_model, model_type, model_name] = modelId.split('/');
|
||||||
|
|
||||||
|
const field: PipelineModelField = {
|
||||||
|
base_model: base_model as BaseModelType,
|
||||||
|
model_name,
|
||||||
|
};
|
||||||
|
|
||||||
|
return field;
|
||||||
|
};
|
@ -1,12 +1,9 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import { DEFAULT_SCHEDULER_NAME, Scheduler } from 'app/constants';
|
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
|
||||||
import { ModelLoaderTypes } from 'features/system/components/ModelSelect';
|
|
||||||
import { configChanged } from 'features/system/store/configSlice';
|
import { configChanged } from 'features/system/store/configSlice';
|
||||||
import { clamp, sortBy } from 'lodash-es';
|
import { clamp } from 'lodash-es';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
import { imageUrlsReceived } from 'services/thunks/image';
|
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
import {
|
import {
|
||||||
CfgScaleParam,
|
CfgScaleParam,
|
||||||
HeightParam,
|
HeightParam,
|
||||||
@ -50,7 +47,6 @@ export interface GenerationState {
|
|||||||
horizontalSymmetrySteps: number;
|
horizontalSymmetrySteps: number;
|
||||||
verticalSymmetrySteps: number;
|
verticalSymmetrySteps: number;
|
||||||
model: ModelParam;
|
model: ModelParam;
|
||||||
currentModelType: ModelLoaderTypes;
|
|
||||||
shouldUseSeamless: boolean;
|
shouldUseSeamless: boolean;
|
||||||
seamlessXAxis: boolean;
|
seamlessXAxis: boolean;
|
||||||
seamlessYAxis: boolean;
|
seamlessYAxis: boolean;
|
||||||
@ -85,7 +81,6 @@ export const initialGenerationState: GenerationState = {
|
|||||||
horizontalSymmetrySteps: 0,
|
horizontalSymmetrySteps: 0,
|
||||||
verticalSymmetrySteps: 0,
|
verticalSymmetrySteps: 0,
|
||||||
model: '',
|
model: '',
|
||||||
currentModelType: 'sd1_model_loader',
|
|
||||||
shouldUseSeamless: false,
|
shouldUseSeamless: false,
|
||||||
seamlessXAxis: true,
|
seamlessXAxis: true,
|
||||||
seamlessYAxis: true,
|
seamlessYAxis: true,
|
||||||
@ -221,33 +216,14 @@ export const generationSlice = createSlice({
|
|||||||
modelSelected: (state, action: PayloadAction<string>) => {
|
modelSelected: (state, action: PayloadAction<string>) => {
|
||||||
state.model = action.payload;
|
state.model = action.payload;
|
||||||
},
|
},
|
||||||
setCurrentModelType: (state, action: PayloadAction<ModelLoaderTypes>) => {
|
|
||||||
state.currentModelType = action.payload;
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
|
||||||
if (!state.model) {
|
|
||||||
const firstModel = sortBy(action.payload, 'name')[0];
|
|
||||||
state.model = firstModel.name;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
builder.addCase(configChanged, (state, action) => {
|
builder.addCase(configChanged, (state, action) => {
|
||||||
const defaultModel = action.payload.sd?.defaultModel;
|
const defaultModel = action.payload.sd?.defaultModel;
|
||||||
if (defaultModel && !state.model) {
|
if (defaultModel && !state.model) {
|
||||||
state.model = defaultModel;
|
state.model = defaultModel;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
|
||||||
// const { image_name, image_url, thumbnail_url } = action.payload;
|
|
||||||
|
|
||||||
// if (state.initialImage?.image_name === image_name) {
|
|
||||||
// state.initialImage.image_url = image_url;
|
|
||||||
// state.initialImage.thumbnail_url = thumbnail_url;
|
|
||||||
// }
|
|
||||||
// });
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -284,7 +260,6 @@ export const {
|
|||||||
setVerticalSymmetrySteps,
|
setVerticalSymmetrySteps,
|
||||||
initialImageChanged,
|
initialImageChanged,
|
||||||
modelSelected,
|
modelSelected,
|
||||||
setCurrentModelType,
|
|
||||||
setShouldUseNoiseSettings,
|
setShouldUseNoiseSettings,
|
||||||
setSeamless,
|
setSeamless,
|
||||||
setSeamlessXAxis,
|
setSeamlessXAxis,
|
||||||
|
@ -154,3 +154,17 @@ export type StrengthParam = z.infer<typeof zStrength>;
|
|||||||
*/
|
*/
|
||||||
export const isValidStrength = (val: unknown): val is StrengthParam =>
|
export const isValidStrength = (val: unknown): val is StrengthParam =>
|
||||||
zStrength.safeParse(val).success;
|
zStrength.safeParse(val).success;
|
||||||
|
|
||||||
|
// /**
|
||||||
|
// * Zod schema for BaseModelType
|
||||||
|
// */
|
||||||
|
// export const zBaseModelType = z.enum(['sd-1', 'sd-2']);
|
||||||
|
// /**
|
||||||
|
// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI.
|
||||||
|
// */
|
||||||
|
// export type BaseModelType = z.infer<typeof zBaseModelType>;
|
||||||
|
// /**
|
||||||
|
// * Validates/type-guards a value as a base model type
|
||||||
|
// */
|
||||||
|
// export const isValidBaseModelType = (val: unknown): val is BaseModelType =>
|
||||||
|
// zBaseModelType.safeParse(val).success;
|
||||||
|
@ -1,39 +1,58 @@
|
|||||||
import { memo, useCallback, useEffect } from 'react';
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import {
|
import { modelSelected } from 'features/parameters/store/generationSlice';
|
||||||
modelSelected,
|
|
||||||
setCurrentModelType,
|
|
||||||
} from 'features/parameters/store/generationSlice';
|
|
||||||
|
|
||||||
import { modelSelector } from '../store/modelSelectors';
|
import { forEach, isString } from 'lodash-es';
|
||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { useListModelsQuery } from 'services/apiSlice';
|
||||||
|
|
||||||
export type ModelLoaderTypes = 'sd1_model_loader' | 'sd2_model_loader';
|
export const MODEL_TYPE_MAP = {
|
||||||
|
'sd-1': 'Stable Diffusion 1.x',
|
||||||
const MODEL_LOADER_MAP = {
|
'sd-2': 'Stable Diffusion 2.x',
|
||||||
'sd-1': 'sd1_model_loader',
|
|
||||||
'sd-2': 'sd2_model_loader',
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const ModelSelect = () => {
|
const ModelSelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const {
|
|
||||||
selectedModel,
|
|
||||||
sd1PipelineModelDropDownData,
|
|
||||||
sd2PipelineModelDropdownData,
|
|
||||||
} = useAppSelector(modelSelector);
|
|
||||||
|
|
||||||
useEffect(() => {
|
const selectedModelId = useAppSelector(
|
||||||
if (selectedModel)
|
(state: RootState) => state.generation.model
|
||||||
dispatch(
|
);
|
||||||
setCurrentModelType(
|
|
||||||
MODEL_LOADER_MAP[selectedModel?.base_model] as ModelLoaderTypes
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
)
|
model_type: 'pipeline',
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!pipelineModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(pipelineModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [pipelineModels]);
|
||||||
|
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() => pipelineModels?.entities[selectedModelId],
|
||||||
|
[pipelineModels?.entities, selectedModelId]
|
||||||
);
|
);
|
||||||
}, [dispatch, selectedModel]);
|
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
(v: string | null) => {
|
(v: string | null) => {
|
||||||
@ -45,13 +64,27 @@ const ModelSelect = () => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstModel = pipelineModels?.ids[0];
|
||||||
|
|
||||||
|
if (!isString(firstModel)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
handleChangeModel(firstModel);
|
||||||
|
}, [handleChangeModel, pipelineModels?.ids, selectedModelId]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
tooltip={selectedModel?.description}
|
tooltip={selectedModel?.description}
|
||||||
label={t('modelManager.model')}
|
label={t('modelManager.model')}
|
||||||
value={selectedModel?.name ?? ''}
|
value={selectedModelId}
|
||||||
placeholder="Pick one"
|
placeholder="Pick one"
|
||||||
data={sd1PipelineModelDropDownData.concat(sd2PipelineModelDropdownData)}
|
data={data}
|
||||||
onChange={handleChangeModel}
|
onChange={handleChangeModel}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
@ -7,13 +7,12 @@ import { systemSelector } from '../store/systemSelectors';
|
|||||||
const isApplicationReadySelector = createSelector(
|
const isApplicationReadySelector = createSelector(
|
||||||
[systemSelector, configSelector],
|
[systemSelector, configSelector],
|
||||||
(system, config) => {
|
(system, config) => {
|
||||||
const { wereModelsReceived, wasSchemaParsed } = system;
|
const { wasSchemaParsed } = system;
|
||||||
|
|
||||||
const { disabledTabs } = config;
|
const { disabledTabs } = config;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
disabledTabs,
|
disabledTabs,
|
||||||
wereModelsReceived,
|
|
||||||
wasSchemaParsed,
|
wasSchemaParsed,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -23,21 +22,17 @@ const isApplicationReadySelector = createSelector(
|
|||||||
* Checks if the application is ready to be used, i.e. if the initial startup process is finished.
|
* Checks if the application is ready to be used, i.e. if the initial startup process is finished.
|
||||||
*/
|
*/
|
||||||
export const useIsApplicationReady = () => {
|
export const useIsApplicationReady = () => {
|
||||||
const { disabledTabs, wereModelsReceived, wasSchemaParsed } = useAppSelector(
|
const { disabledTabs, wasSchemaParsed } = useAppSelector(
|
||||||
isApplicationReadySelector
|
isApplicationReadySelector
|
||||||
);
|
);
|
||||||
|
|
||||||
const isApplicationReady = useMemo(() => {
|
const isApplicationReady = useMemo(() => {
|
||||||
if (!wereModelsReceived) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!disabledTabs.includes('nodes') && !wasSchemaParsed) {
|
if (!disabledTabs.includes('nodes') && !wasSchemaParsed) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}, [disabledTabs, wereModelsReceived, wasSchemaParsed]);
|
}, [disabledTabs, wasSchemaParsed]);
|
||||||
|
|
||||||
return isApplicationReady;
|
return isApplicationReady;
|
||||||
};
|
};
|
||||||
|
@ -1,59 +0,0 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { IAISelectDataType } from 'common/components/IAIMantineSelect';
|
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
|
||||||
import { isEqual } from 'lodash-es';
|
|
||||||
|
|
||||||
import {
|
|
||||||
selectAllSD1PipelineModels,
|
|
||||||
selectByIdSD1PipelineModels,
|
|
||||||
} from './models/sd1PipelineModelSlice';
|
|
||||||
|
|
||||||
import {
|
|
||||||
selectAllSD2PipelineModels,
|
|
||||||
selectByIdSD2PipelineModels,
|
|
||||||
} from './models/sd2PipelineModelSlice';
|
|
||||||
|
|
||||||
export const modelSelector = createSelector(
|
|
||||||
[(state: RootState) => state, generationSelector],
|
|
||||||
(state, generation) => {
|
|
||||||
let selectedModel = selectByIdSD1PipelineModels(state, generation.model);
|
|
||||||
if (selectedModel === undefined)
|
|
||||||
selectedModel = selectByIdSD2PipelineModels(state, generation.model);
|
|
||||||
|
|
||||||
const sd1PipelineModels = selectAllSD1PipelineModels(state);
|
|
||||||
const sd2PipelineModels = selectAllSD2PipelineModels(state);
|
|
||||||
|
|
||||||
const allPipelineModels = sd1PipelineModels.concat(sd2PipelineModels);
|
|
||||||
|
|
||||||
const sd1PipelineModelDropDownData = selectAllSD1PipelineModels(state)
|
|
||||||
.map<IAISelectDataType>((m) => ({
|
|
||||||
value: m.name,
|
|
||||||
label: m.name,
|
|
||||||
group: '1.x Models',
|
|
||||||
}))
|
|
||||||
.sort((a, b) => a.label.localeCompare(b.label));
|
|
||||||
|
|
||||||
const sd2PipelineModelDropdownData = selectAllSD2PipelineModels(state)
|
|
||||||
.map<IAISelectDataType>((m) => ({
|
|
||||||
value: m.name,
|
|
||||||
label: m.name,
|
|
||||||
group: '2.x Models',
|
|
||||||
}))
|
|
||||||
.sort((a, b) => a.label.localeCompare(b.label));
|
|
||||||
|
|
||||||
return {
|
|
||||||
selectedModel,
|
|
||||||
allPipelineModels,
|
|
||||||
sd1PipelineModels,
|
|
||||||
sd2PipelineModels,
|
|
||||||
sd1PipelineModelDropDownData,
|
|
||||||
sd2PipelineModelDropdownData,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
@ -1,56 +0,0 @@
|
|||||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import {
|
|
||||||
StableDiffusion1ModelCheckpointConfig,
|
|
||||||
StableDiffusion1ModelDiffusersConfig,
|
|
||||||
} from 'services/api';
|
|
||||||
|
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
|
|
||||||
export type SD1PipelineModel = (
|
|
||||||
| StableDiffusion1ModelCheckpointConfig
|
|
||||||
| StableDiffusion1ModelDiffusersConfig
|
|
||||||
) & {
|
|
||||||
name: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const sd1PipelineModelsAdapter = createEntityAdapter<SD1PipelineModel>({
|
|
||||||
selectId: (model) => model.name,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
|
|
||||||
export const sd1InitialPipelineModelsState =
|
|
||||||
sd1PipelineModelsAdapter.getInitialState();
|
|
||||||
|
|
||||||
export type SD1PipelineModelState = typeof sd1InitialPipelineModelsState;
|
|
||||||
|
|
||||||
export const sd1PipelineModelsSlice = createSlice({
|
|
||||||
name: 'sd1PipelineModels',
|
|
||||||
initialState: sd1InitialPipelineModelsState,
|
|
||||||
reducers: {
|
|
||||||
modelAdded: sd1PipelineModelsAdapter.upsertOne,
|
|
||||||
},
|
|
||||||
extraReducers(builder) {
|
|
||||||
/**
|
|
||||||
* Received Models - FULFILLED
|
|
||||||
*/
|
|
||||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
|
||||||
if (action.meta.arg.baseModel !== 'sd-1') return;
|
|
||||||
sd1PipelineModelsAdapter.setAll(state, action.payload);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
export const {
|
|
||||||
selectAll: selectAllSD1PipelineModels,
|
|
||||||
selectById: selectByIdSD1PipelineModels,
|
|
||||||
selectEntities: selectEntitiesSD1PipelineModels,
|
|
||||||
selectIds: selectIdsSD1PipelineModels,
|
|
||||||
selectTotal: selectTotalSD1PipelineModels,
|
|
||||||
} = sd1PipelineModelsAdapter.getSelectors<RootState>(
|
|
||||||
(state) => state.sd1pipelinemodels
|
|
||||||
);
|
|
||||||
|
|
||||||
export const { modelAdded } = sd1PipelineModelsSlice.actions;
|
|
||||||
|
|
||||||
export default sd1PipelineModelsSlice.reducer;
|
|
@ -1,56 +0,0 @@
|
|||||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import {
|
|
||||||
StableDiffusion2ModelCheckpointConfig,
|
|
||||||
StableDiffusion2ModelDiffusersConfig,
|
|
||||||
} from 'services/api';
|
|
||||||
|
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
|
|
||||||
export type SD2PipelineModel = (
|
|
||||||
| StableDiffusion2ModelCheckpointConfig
|
|
||||||
| StableDiffusion2ModelDiffusersConfig
|
|
||||||
) & {
|
|
||||||
name: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const sd2PipelineModelsAdapater = createEntityAdapter<SD2PipelineModel>({
|
|
||||||
selectId: (model) => model.name,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
|
|
||||||
export const sd2InitialPipelineModelsState =
|
|
||||||
sd2PipelineModelsAdapater.getInitialState();
|
|
||||||
|
|
||||||
export type SD2PipelineModelState = typeof sd2InitialPipelineModelsState;
|
|
||||||
|
|
||||||
export const sd2PipelineModelsSlice = createSlice({
|
|
||||||
name: 'sd2PipelineModels',
|
|
||||||
initialState: sd2InitialPipelineModelsState,
|
|
||||||
reducers: {
|
|
||||||
modelAdded: sd2PipelineModelsAdapater.upsertOne,
|
|
||||||
},
|
|
||||||
extraReducers(builder) {
|
|
||||||
/**
|
|
||||||
* Received Models - FULFILLED
|
|
||||||
*/
|
|
||||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
|
||||||
if (action.meta.arg.baseModel !== 'sd-2') return;
|
|
||||||
sd2PipelineModelsAdapater.setAll(state, action.payload);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
export const {
|
|
||||||
selectAll: selectAllSD2PipelineModels,
|
|
||||||
selectById: selectByIdSD2PipelineModels,
|
|
||||||
selectEntities: selectEntitiesSD2PipelineModels,
|
|
||||||
selectIds: selectIdsSD2PipelineModels,
|
|
||||||
selectTotal: selectTotalSD2PipelineModels,
|
|
||||||
} = sd2PipelineModelsAdapater.getSelectors<RootState>(
|
|
||||||
(state) => state.sd2pipelinemodels
|
|
||||||
);
|
|
||||||
|
|
||||||
export const { modelAdded } = sd2PipelineModelsSlice.actions;
|
|
||||||
|
|
||||||
export default sd2PipelineModelsSlice.reducer;
|
|
@ -1,9 +0,0 @@
|
|||||||
import { SD1PipelineModelState } from './models/sd1PipelineModelSlice';
|
|
||||||
import { SD2PipelineModelState } from './models/sd2PipelineModelSlice';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Models slice persist denylist
|
|
||||||
*/
|
|
||||||
export const modelsPersistDenylist:
|
|
||||||
| (keyof SD1PipelineModelState)[]
|
|
||||||
| (keyof SD2PipelineModelState)[] = ['entities', 'ids'];
|
|
@ -20,7 +20,6 @@ import {
|
|||||||
} from 'services/events/actions';
|
} from 'services/events/actions';
|
||||||
import { ProgressImage } from 'services/events/types';
|
import { ProgressImage } from 'services/events/types';
|
||||||
import { imageUploaded } from 'services/thunks/image';
|
import { imageUploaded } from 'services/thunks/image';
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
|
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
|
||||||
import { makeToast } from '../../../app/components/Toaster';
|
import { makeToast } from '../../../app/components/Toaster';
|
||||||
import { LANGUAGES } from '../components/LanguagePicker';
|
import { LANGUAGES } from '../components/LanguagePicker';
|
||||||
@ -377,13 +376,6 @@ export const systemSlice = createSlice({
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
|
||||||
* Received available models from the backend
|
|
||||||
*/
|
|
||||||
builder.addCase(receivedModels.fulfilled, (state) => {
|
|
||||||
state.wereModelsReceived = true;
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* OpenAPI schema was parsed
|
* OpenAPI schema was parsed
|
||||||
*/
|
*/
|
||||||
|
@ -13,23 +13,68 @@ import {
|
|||||||
TagTypesFrom,
|
TagTypesFrom,
|
||||||
TagTypesFromApi,
|
TagTypesFromApi,
|
||||||
} from '@reduxjs/toolkit/dist/query/endpointDefinitions';
|
} from '@reduxjs/toolkit/dist/query/endpointDefinitions';
|
||||||
|
import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
|
||||||
|
import { BaseModelType } from './api/models/BaseModelType';
|
||||||
|
import { ModelType } from './api/models/ModelType';
|
||||||
|
import { ModelsList } from './api/models/ModelsList';
|
||||||
|
import { keyBy } from 'lodash-es';
|
||||||
|
|
||||||
type ListBoardsArg = { offset: number; limit: number };
|
type ListBoardsArg = { offset: number; limit: number };
|
||||||
type UpdateBoardArg = { board_id: string; changes: BoardChanges };
|
type UpdateBoardArg = { board_id: string; changes: BoardChanges };
|
||||||
type AddImageToBoardArg = { board_id: string; image_name: string };
|
type AddImageToBoardArg = { board_id: string; image_name: string };
|
||||||
type RemoveImageFromBoardArg = { board_id: string; image_name: string };
|
type RemoveImageFromBoardArg = { board_id: string; image_name: string };
|
||||||
type ListBoardImagesArg = { board_id: string; offset: number; limit: number };
|
type ListBoardImagesArg = { board_id: string; offset: number; limit: number };
|
||||||
|
type ListModelsArg = { base_model?: BaseModelType; model_type?: ModelType };
|
||||||
|
|
||||||
const tagTypes = ['Board', 'Image'];
|
type ModelConfig = ModelsList['models'][number];
|
||||||
|
|
||||||
|
const tagTypes = ['Board', 'Image', 'Model'];
|
||||||
type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>;
|
type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>;
|
||||||
|
|
||||||
const LIST = 'LIST';
|
const LIST = 'LIST';
|
||||||
|
|
||||||
|
const modelsAdapter = createEntityAdapter<ModelConfig>({
|
||||||
|
selectId: (model) => getModelId(model),
|
||||||
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
|
});
|
||||||
|
|
||||||
|
const getModelId = ({ base_model, type, name }: ModelConfig) =>
|
||||||
|
`${base_model}/${type}/${name}`;
|
||||||
|
|
||||||
export const api = createApi({
|
export const api = createApi({
|
||||||
baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }),
|
baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }),
|
||||||
reducerPath: 'api',
|
reducerPath: 'api',
|
||||||
tagTypes,
|
tagTypes,
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
|
/**
|
||||||
|
* Models Queries
|
||||||
|
*/
|
||||||
|
|
||||||
|
listModels: build.query<EntityState<ModelConfig>, ListModelsArg>({
|
||||||
|
query: (arg) => ({ url: 'models/', params: arg }),
|
||||||
|
providesTags: (result, error, arg) => {
|
||||||
|
// any list of boards
|
||||||
|
const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST }];
|
||||||
|
|
||||||
|
if (result) {
|
||||||
|
// and individual tags for each board
|
||||||
|
tags.push(
|
||||||
|
...result.ids.map((id) => ({
|
||||||
|
type: 'Model' as const,
|
||||||
|
id,
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return tags;
|
||||||
|
},
|
||||||
|
transformResponse: (response: ModelsList, meta, arg) => {
|
||||||
|
return modelsAdapter.addMany(
|
||||||
|
modelsAdapter.getInitialState(),
|
||||||
|
keyBy(response.models, getModelId)
|
||||||
|
);
|
||||||
|
},
|
||||||
|
}),
|
||||||
/**
|
/**
|
||||||
* Boards Queries
|
* Boards Queries
|
||||||
*/
|
*/
|
||||||
@ -174,4 +219,5 @@ export const {
|
|||||||
useRemoveImageFromBoardMutation,
|
useRemoveImageFromBoardMutation,
|
||||||
useListBoardImagesQuery,
|
useListBoardImagesQuery,
|
||||||
useGetImageDTOQuery,
|
useGetImageDTOQuery,
|
||||||
|
useListModelsQuery,
|
||||||
} = api;
|
} = api;
|
||||||
|
@ -1,58 +0,0 @@
|
|||||||
import { log } from 'app/logging/useLogger';
|
|
||||||
import { createAppAsyncThunk } from 'app/store/storeUtils';
|
|
||||||
import { SD1PipelineModel } from 'features/system/store/models/sd1PipelineModelSlice';
|
|
||||||
import { SD2PipelineModel } from 'features/system/store/models/sd2PipelineModelSlice';
|
|
||||||
import { reduce, size } from 'lodash-es';
|
|
||||||
import { BaseModelType, ModelType, ModelsService } from 'services/api';
|
|
||||||
|
|
||||||
const models = log.child({ namespace: 'model' });
|
|
||||||
|
|
||||||
export const IMAGES_PER_PAGE = 20;
|
|
||||||
|
|
||||||
type receivedModelsArg = {
|
|
||||||
baseModel: BaseModelType | undefined;
|
|
||||||
modelType: ModelType | undefined;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const receivedModels = createAppAsyncThunk(
|
|
||||||
'models/receivedModels',
|
|
||||||
async (arg: receivedModelsArg) => {
|
|
||||||
const response = await ModelsService.listModels(arg);
|
|
||||||
|
|
||||||
let deserializedModels = {};
|
|
||||||
|
|
||||||
if (arg.baseModel === undefined) return response.models;
|
|
||||||
if (arg.modelType === undefined) return response.models;
|
|
||||||
|
|
||||||
if (arg.baseModel === 'sd-1') {
|
|
||||||
deserializedModels = reduce(
|
|
||||||
response.models[arg.baseModel][arg.modelType],
|
|
||||||
(modelsAccumulator, model, modelName) => {
|
|
||||||
modelsAccumulator[modelName] = { ...model, name: modelName };
|
|
||||||
return modelsAccumulator;
|
|
||||||
},
|
|
||||||
{} as Record<string, SD1PipelineModel>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (arg.baseModel === 'sd-2') {
|
|
||||||
deserializedModels = reduce(
|
|
||||||
response.models[arg.baseModel][arg.modelType],
|
|
||||||
(modelsAccumulator, model, modelName) => {
|
|
||||||
modelsAccumulator[modelName] = { ...model, name: modelName };
|
|
||||||
return modelsAccumulator;
|
|
||||||
},
|
|
||||||
{} as Record<string, SD2PipelineModel>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
models.info(
|
|
||||||
{ response },
|
|
||||||
`Received ${size(response.models[arg.baseModel][arg.modelType])} ${[
|
|
||||||
arg.baseModel,
|
|
||||||
]} models`
|
|
||||||
);
|
|
||||||
|
|
||||||
return deserializedModels;
|
|
||||||
}
|
|
||||||
);
|
|
Loading…
Reference in New Issue
Block a user