From db8862d86069ca9a4baea6e080ebda0bcf92e917 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 4 Jul 2023 21:12:52 +1000 Subject: [PATCH] feat(ui): add LoRA ui & update graphs --- .../enhancers/reduxRemember/serialize.ts | 2 - invokeai/frontend/web/src/app/store/store.ts | 23 +-- .../features/lora/components/ParamLora.tsx | 58 +++++++ .../lora/components/ParamLoraCollapse.tsx | 20 +++ .../lora/components/ParamLoraList.tsx | 19 +++ .../lora/components/ParamLoraSelect.tsx | 103 ++++++++++++ .../web/src/features/lora/store/loraSlice.ts | 44 +++++ .../nodes/components/InputFieldComponent.tsx | 11 ++ .../fields/LoRAModelInputFieldComponent.tsx | 104 ++++++++++++ .../src/features/nodes/store/nodesSlice.ts | 8 +- .../web/src/features/nodes/types/constants.ts | 9 +- .../web/src/features/nodes/types/types.ts | 13 ++ .../nodes/util/fieldTemplateBuilders.ts | 19 +++ .../features/nodes/util/fieldValueBuilders.ts | 4 + .../util/graphBuilders/addLoRAsToGraph.ts | 150 ++++++++++++++++++ .../buildCanvasImageToImageGraph.ts | 3 + .../buildCanvasTextToImageGraph.ts | 3 + .../buildLinearImageToImageGraph.ts | 4 + .../buildLinearTextToImageGraph.ts | 3 + .../util/graphBuilders/buildNodesGraph.ts | 7 + .../nodes/util/graphBuilders/constants.ts | 1 + .../features/nodes/util/modelIdToLoRAName.ts | 12 ++ .../ImageToImageTabParameters.tsx | 22 +-- .../TextToImage/TextToImageTabParameters.tsx | 20 +-- .../web/src/services/api/endpoints/models.ts | 4 +- .../frontend/web/src/services/api/types.d.ts | 3 + 26 files changed, 630 insertions(+), 39 deletions(-) create mode 100644 invokeai/frontend/web/src/features/lora/components/ParamLora.tsx create mode 100644 invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx create mode 100644 invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx create mode 100644 invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx create mode 100644 invokeai/frontend/web/src/features/lora/store/loraSlice.ts create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts index cb18d48301..ac1b9c5205 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts @@ -20,10 +20,8 @@ const serializationDenylist: { nodes: nodesPersistDenylist, postprocessing: postprocessingPersistDenylist, system: systemPersistDenylist, - // config: configPersistDenyList, ui: uiPersistDenylist, controlNet: controlNetDenylist, - // hotkeys: hotkeysPersistDenylist, }; export const serialize: SerializeFunction = (data, key) => { diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 2fd071bd23..5208933e7b 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -8,31 +8,32 @@ import { import dynamicMiddlewares from 'redux-dynamic-middlewares'; import { rememberEnhancer, rememberReducer } from 'redux-remember'; +import batchReducer from 'features/batch/store/batchSlice'; import canvasReducer from 'features/canvas/store/canvasSlice'; import controlNetReducer from 'features/controlNet/store/controlNetSlice'; +import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice'; +import boardsReducer from 'features/gallery/store/boardSlice'; import galleryReducer from 'features/gallery/store/gallerySlice'; +import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice'; +import loraReducer from 'features/lora/store/loraSlice'; +import nodesReducer from 'features/nodes/store/nodesSlice'; import generationReducer from 'features/parameters/store/generationSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; -import systemReducer from 'features/system/store/systemSlice'; -import nodesReducer from 'features/nodes/store/nodesSlice'; -import boardsReducer from 'features/gallery/store/boardSlice'; import configReducer from 'features/system/store/configSlice'; +import systemReducer from 'features/system/store/systemSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import uiReducer from 'features/ui/store/uiSlice'; -import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice'; -import batchReducer from 'features/batch/store/batchSlice'; -import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice'; import { listenerMiddleware } from './middleware/listenerMiddleware'; -import { actionSanitizer } from './middleware/devtools/actionSanitizer'; -import { actionsDenylist } from './middleware/devtools/actionsDenylist'; -import { stateSanitizer } from './middleware/devtools/stateSanitizer'; +import { api } from 'services/api'; import { LOCALSTORAGE_PREFIX } from './constants'; import { serialize } from './enhancers/reduxRemember/serialize'; import { unserialize } from './enhancers/reduxRemember/unserialize'; -import { api } from 'services/api'; +import { actionSanitizer } from './middleware/devtools/actionSanitizer'; +import { actionsDenylist } from './middleware/devtools/actionsDenylist'; +import { stateSanitizer } from './middleware/devtools/stateSanitizer'; const allReducers = { canvas: canvasReducer, @@ -50,6 +51,7 @@ const allReducers = { dynamicPrompts: dynamicPromptsReducer, batch: batchReducer, imageDeletion: imageDeletionReducer, + lora: loraReducer, [api.reducerPath]: api.reducer, }; @@ -69,6 +71,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [ 'controlNet', 'dynamicPrompts', 'batch', + 'lora', // 'boards', // 'hotkeys', // 'config', diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx new file mode 100644 index 0000000000..c7d1c44fd3 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx @@ -0,0 +1,58 @@ +import { Flex } from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import IAISlider from 'common/components/IAISlider'; +import { memo, useCallback } from 'react'; +import { FaTrash } from 'react-icons/fa'; +import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice'; + +type Props = { + lora: Lora; +}; + +const ParamLora = (props: Props) => { + const dispatch = useAppDispatch(); + const { lora } = props; + + const handleChange = useCallback( + (v: number) => { + dispatch(loraWeightChanged({ name: lora.name, weight: v })); + }, + [dispatch, lora.name] + ); + + const handleReset = useCallback(() => { + dispatch(loraWeightChanged({ name: lora.name, weight: 1 })); + }, [dispatch, lora.name]); + + const handleRemoveLora = useCallback(() => { + dispatch(loraRemoved(lora.name)); + }, [dispatch, lora.name]); + + return ( + + + } + colorScheme="error" + /> + + ); +}; + +export default memo(ParamLora); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx new file mode 100644 index 0000000000..fb088bef8a --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx @@ -0,0 +1,20 @@ +import { Flex, useDisclosure } from '@chakra-ui/react'; +import IAICollapse from 'common/components/IAICollapse'; +import { memo } from 'react'; +import ParamLoraList from './ParamLoraList'; +import ParamLoraSelect from './ParamLoraSelect'; + +const ParamLoraCollapse = () => { + const { isOpen, onToggle } = useDisclosure(); + + return ( + + + + + + + ); +}; + +export default memo(ParamLoraCollapse); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx new file mode 100644 index 0000000000..8d6ff98498 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx @@ -0,0 +1,19 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { map } from 'lodash-es'; +import ParamLora from './ParamLora'; + +const selector = createSelector(stateSelector, ({ lora }) => { + const { loras } = lora; + + return { loras }; +}); + +const ParamLoraList = () => { + const { loras } = useAppSelector(selector); + + return map(loras, (lora) => ); +}; + +export default ParamLoraList; diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx new file mode 100644 index 0000000000..8e44e7d8f1 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx @@ -0,0 +1,103 @@ +import { Text } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; +import { forEach } from 'lodash-es'; +import { forwardRef, useCallback, useMemo } from 'react'; +import { useListModelsQuery } from 'services/api/endpoints/models'; +import { loraAdded } from '../store/loraSlice'; + +type LoraSelectItem = { + label: string; + value: string; + description?: string; +}; + +const selector = createSelector( + stateSelector, + ({ lora }) => ({ + loras: lora.loras, + }), + defaultSelectorOptions +); + +const ParamLoraSelect = () => { + const dispatch = useAppDispatch(); + const { loras } = useAppSelector(selector); + const { data: lorasQueryData } = useListModelsQuery({ model_type: 'lora' }); + + const data = useMemo(() => { + if (!lorasQueryData) { + return []; + } + + const data: LoraSelectItem[] = []; + + forEach(lorasQueryData.entities, (lora, id) => { + if (!lora || Boolean(id in loras)) { + return; + } + + data.push({ + value: id, + label: lora.name, + description: lora.description, + }); + }); + + return data; + }, [loras, lorasQueryData]); + + const handleChange = useCallback( + (v: string[]) => { + v[0] && dispatch(loraAdded(v[0])); + }, + [dispatch] + ); + + return ( + + item.label.toLowerCase().includes(value.toLowerCase().trim()) || + item.value.toLowerCase().includes(value.toLowerCase().trim()) + } + onChange={handleChange} + /> + ); +}; + +interface ItemProps extends React.ComponentPropsWithoutRef<'div'> { + value: string; + label: string; + description?: string; +} + +const SelectItem = forwardRef( + ({ label, description, ...others }: ItemProps, ref) => { + return ( +
+
+ {label} + {description && ( + + {description} + + )} +
+
+ ); + } +); + +SelectItem.displayName = 'SelectItem'; + +export default ParamLoraSelect; diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts new file mode 100644 index 0000000000..49b316b054 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -0,0 +1,44 @@ +import { PayloadAction, createSlice } from '@reduxjs/toolkit'; + +export type Lora = { + name: string; + weight: number; +}; + +export const defaultLoRAConfig: Omit = { + weight: 1, +}; + +export type LoraState = { + loras: Record; +}; + +export const intialLoraState: LoraState = { + loras: {}, +}; + +export const loraSlice = createSlice({ + name: 'lora', + initialState: intialLoraState, + reducers: { + loraAdded: (state, action: PayloadAction) => { + const name = action.payload; + state.loras[name] = { name, ...defaultLoRAConfig }; + }, + loraRemoved: (state, action: PayloadAction) => { + const name = action.payload; + delete state.loras[name]; + }, + loraWeightChanged: ( + state, + action: PayloadAction<{ name: string; weight: number }> + ) => { + const { name, weight } = action.payload; + state.loras[name].weight = weight; + }, + }, +}); + +export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions; + +export default loraSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx index 062fec2fdc..9925a48381 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx @@ -12,6 +12,7 @@ import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFie import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; import ItemInputFieldComponent from './fields/ItemInputFieldComponent'; import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; +import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent'; import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent'; @@ -163,6 +164,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => { ); } + if (type === 'lora_model' && template.type === 'lora_model') { + return ( + + ); + } + if (type === 'array' && template.type === 'array') { return ( +) => { + const { nodeId, field } = props; + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const { data: loraModels } = useListModelsQuery({ + model_type: 'lora', + }); + + const selectedModel = useMemo( + () => loraModels?.entities[field.value ?? loraModels.ids[0]], + [loraModels?.entities, loraModels?.ids, field.value] + ); + + const data = useMemo(() => { + if (!loraModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(loraModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: BASE_MODEL_NAME_MAP[model.base_model], + }); + }); + + return data; + }, [loraModels]); + + const handleValueChanged = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: v, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + useEffect(() => { + if (field.value && loraModels?.ids.includes(field.value)) { + return; + } + + const firstLora = loraModels?.ids[0]; + + if (!isString(firstLora)) { + return; + } + + handleValueChanged(firstLora); + }, [field.value, handleValueChanged, loraModels?.ids]); + + return ( + + ); +}; + +export default memo(LoRAModelInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index ffc93db2ba..4fa69c626b 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -1,5 +1,8 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { cloneDeep, uniqBy } from 'lodash-es'; import { OpenAPIV3 } from 'openapi-types'; +import { RgbaColor } from 'react-colorful'; import { addEdge, applyEdgeChanges, @@ -11,12 +14,9 @@ import { NodeChange, OnConnectStartParams, } from 'reactflow'; -import { ImageField } from 'services/api/types'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; +import { ImageField } from 'services/api/types'; import { InvocationTemplate, InvocationValue } from '../types/types'; -import { RgbaColor } from 'react-colorful'; -import { RootState } from 'app/store/store'; -import { cloneDeep, isArray, uniq, uniqBy } from 'lodash-es'; export type NodesState = { nodes: Node[]; diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index b864501803..5fe780a286 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -18,6 +18,7 @@ export const FIELD_TYPE_MAP: Record = { VaeField: 'vae', model: 'model', vae_model: 'vae_model', + lora_model: 'lora_model', array: 'array', item: 'item', ColorField: 'color', @@ -120,7 +121,13 @@ export const FIELDS: Record = { vae_model: { color: 'teal', colorCssVar: getColorTokenCssVariable('teal'), - title: 'Model', + title: 'VAE', + description: 'Models are models.', + }, + lora_model: { + color: 'teal', + colorCssVar: getColorTokenCssVariable('teal'), + title: 'LoRA', description: 'Models are models.', }, array: { diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index c7e573ace2..3de8cae9ff 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -65,6 +65,7 @@ export type FieldType = | 'control' | 'model' | 'vae_model' + | 'lora_model' | 'array' | 'item' | 'color' @@ -93,6 +94,7 @@ export type InputFieldValue = | EnumInputFieldValue | ModelInputFieldValue | VaeModelInputFieldValue + | LoRAModelInputFieldValue | ArrayInputFieldValue | ItemInputFieldValue | ColorInputFieldValue @@ -119,6 +121,7 @@ export type InputFieldTemplate = | EnumInputFieldTemplate | ModelInputFieldTemplate | VaeModelInputFieldTemplate + | LoRAModelInputFieldTemplate | ArrayInputFieldTemplate | ItemInputFieldTemplate | ColorInputFieldTemplate @@ -236,6 +239,11 @@ export type VaeModelInputFieldValue = FieldValueBase & { value?: string; }; +export type LoRAModelInputFieldValue = FieldValueBase & { + type: 'lora_model'; + value?: string; +}; + export type ArrayInputFieldValue = FieldValueBase & { type: 'array'; value?: (string | number)[]; @@ -350,6 +358,11 @@ export type VaeModelInputFieldTemplate = InputFieldTemplateBase & { type: 'vae_model'; }; +export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & { + default: string; + type: 'lora_model'; +}; + export type ArrayInputFieldTemplate = InputFieldTemplateBase & { default: []; type: 'array'; diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index c71618175a..1c2dbc0c3e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -18,6 +18,7 @@ import { IntegerInputFieldTemplate, ItemInputFieldTemplate, LatentsInputFieldTemplate, + LoRAModelInputFieldTemplate, ModelInputFieldTemplate, OutputFieldTemplate, StringInputFieldTemplate, @@ -191,6 +192,21 @@ const buildVaeModelInputFieldTemplate = ({ return template; }; +const buildLoRAModelInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): LoRAModelInputFieldTemplate => { + const template: LoRAModelInputFieldTemplate = { + ...baseField, + type: 'lora_model', + inputRequirement: 'always', + inputKind: 'direct', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildImageInputFieldTemplate = ({ schemaObject, baseField, @@ -460,6 +476,9 @@ export const buildInputFieldTemplate = ( if (['vae_model'].includes(fieldType)) { return buildVaeModelInputFieldTemplate({ schemaObject, baseField }); } + if (['lora_model'].includes(fieldType)) { + return buildLoRAModelInputFieldTemplate({ schemaObject, baseField }); + } if (['enum'].includes(fieldType)) { return buildEnumInputFieldTemplate({ schemaObject, baseField }); } diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index a94d3ddef2..950038b691 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -79,6 +79,10 @@ export const buildInputFieldValue = ( if (template.type === 'vae_model') { fieldValue.value = undefined; } + + if (template.type === 'lora_model') { + fieldValue.value = undefined; + } } return fieldValue; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts new file mode 100644 index 0000000000..a105a123d8 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts @@ -0,0 +1,150 @@ +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { forEach, size } from 'lodash-es'; +import { LoraLoaderInvocation } from 'services/api/types'; +import { modelIdToLoRAModelField } from '../modelIdToLoRAName'; +import { + LORA_LOADER, + MAIN_MODEL_LOADER, + NEGATIVE_CONDITIONING, + POSITIVE_CONDITIONING, +} from './constants'; + +export const addLoRAsToGraph = ( + graph: NonNullableGraph, + state: RootState, + baseNodeId: string +): void => { + /** + * LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them. + * They then output the UNet and CLIP models references on to either the next LoRA in the chain, + * or to the inference/conditioning nodes. + * + * So we need to inject a LoRA chain into the graph. + */ + + const { loras } = state.lora; + const loraCount = size(loras); + + if (loraCount > 0) { + // remove any existing connections from main model loader, we need to insert the lora nodes + graph.edges = graph.edges.filter( + (e) => + !( + e.source.node_id === MAIN_MODEL_LOADER && + ['unet', 'clip'].includes(e.source.field) + ) + ); + } + + // we need to remember the last lora so we can chain from it + let lastLoraNodeId = ''; + let currentLoraIndex = 0; + + forEach(loras, (lora) => { + const { name, weight } = lora; + const loraField = modelIdToLoRAModelField(name); + const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace( + '.', + '_' + )}`; + + console.log(lastLoraNodeId, currentLoraNodeId, currentLoraIndex, loraField); + + const loraLoaderNode: LoraLoaderInvocation = { + type: 'lora_loader', + id: currentLoraNodeId, + lora: loraField, + weight, + }; + + graph.nodes[currentLoraNodeId] = loraLoaderNode; + + if (currentLoraIndex === 0) { + // first lora = start the lora chain, attach directly to model loader + graph.edges.push({ + source: { + node_id: MAIN_MODEL_LOADER, + field: 'unet', + }, + destination: { + node_id: currentLoraNodeId, + field: 'unet', + }, + }); + + graph.edges.push({ + source: { + node_id: MAIN_MODEL_LOADER, + field: 'clip', + }, + destination: { + node_id: currentLoraNodeId, + field: 'clip', + }, + }); + } else { + // we are in the middle of the lora chain, instead connect to the previous lora + graph.edges.push({ + source: { + node_id: lastLoraNodeId, + field: 'unet', + }, + destination: { + node_id: currentLoraNodeId, + field: 'unet', + }, + }); + graph.edges.push({ + source: { + node_id: lastLoraNodeId, + field: 'clip', + }, + destination: { + node_id: currentLoraNodeId, + field: 'clip', + }, + }); + } + + if (currentLoraIndex === loraCount - 1) { + // final lora, end the lora chain - we need to connect up to inference and conditioning nodes + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'unet', + }, + destination: { + node_id: baseNodeId, + field: 'unet', + }, + }); + + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'clip', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip', + }, + }); + + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'clip', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip', + }, + }); + } + + // increment the lora for the next one in the chain + lastLoraNodeId = currentLoraNodeId; + currentLoraIndex += 1; + }); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts index 5cf9882ac1..1843efef84 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts @@ -9,6 +9,7 @@ import { import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { IMAGE_TO_IMAGE_GRAPH, @@ -252,6 +253,8 @@ export const buildCanvasImageToImageGraph = ( }); } + addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS); + // Add VAE addVAEToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts index cfe5e62805..976ea4fd01 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts @@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { LATENTS_TO_IMAGE, @@ -157,6 +158,8 @@ export const buildCanvasTextToImageGraph = ( ], }; + addLoRAsToGraph(graph, state, TEXT_TO_LATENTS); + // Add VAE addVAEToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts index 2e4383c3e7..fe6d1292e4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts @@ -10,6 +10,7 @@ import { import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { IMAGE_COLLECTION, @@ -304,6 +305,9 @@ export const buildLinearImageToImageGraph = ( }, }); } + + addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS); + // Add VAE addVAEToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts index e0e71a00a2..04dccf4983 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts @@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { LATENTS_TO_IMAGE, @@ -150,6 +151,8 @@ export const buildLinearTextToImageGraph = ( ], }; + addLoRAsToGraph(graph, state, TEXT_TO_LATENTS); + // Add Custom VAE Support addVAEToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts index 3265a0f889..12a567b009 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts @@ -4,6 +4,7 @@ import { cloneDeep, omit, reduce } from 'lodash-es'; import { Graph } from 'services/api/types'; import { AnyInvocation } from 'services/events/types'; import { v4 as uuidv4 } from 'uuid'; +import { modelIdToLoRAModelField } from '../modelIdToLoRAName'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { modelIdToVAEModelField } from '../modelIdToVAEModelField'; @@ -38,6 +39,12 @@ export const parseFieldValue = (field: InputFieldValue) => { } } + if (field.type === 'lora_model') { + if (field.value) { + return modelIdToLoRAModelField(field.value); + } + } + return field.value; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts index 58a7d0335b..7aace48def 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts @@ -9,6 +9,7 @@ export const RANGE_OF_SIZE = 'range_of_size'; export const ITERATE = 'iterate'; export const MAIN_MODEL_LOADER = 'main_model_loader'; export const VAE_LOADER = 'vae_loader'; +export const LORA_LOADER = 'lora_loader'; export const IMAGE_TO_LATENTS = 'image_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const RESIZE = 'resize_image'; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts new file mode 100644 index 0000000000..052b58484b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts @@ -0,0 +1,12 @@ +import { BaseModelType, LoRAModelField } from 'services/api/types'; + +export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => { + const [base_model, model_type, model_name] = loraId.split('/'); + + const field: LoRAModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx index 4f04abffa1..32b71d6187 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx @@ -1,14 +1,15 @@ -import { memo } from 'react'; -import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; -import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; -import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; -import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; -import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; -import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; -import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; -import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters'; -import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; +import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; +import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; +import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; +import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; +import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; +import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; +import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; +import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; +import { memo } from 'react'; +import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters'; const ImageToImageTabParameters = () => { return ( @@ -17,6 +18,7 @@ const ImageToImageTabParameters = () => { + diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx index bcc6c91ae6..6291b69a8e 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx @@ -1,15 +1,16 @@ +import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; +import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; +import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; +import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; +import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse'; +import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; +import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; +import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; +import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; import { memo } from 'react'; -import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; -import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; -import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; -import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; -import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; -import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse'; -import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; import TextToImageTabCoreParameters from './TextToImageTabCoreParameters'; -import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; -import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; const TextToImageTabParameters = () => { return ( @@ -18,6 +19,7 @@ const TextToImageTabParameters = () => { + diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 39e4e46d3b..bff412bacb 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -1,6 +1,6 @@ -import { ModelsList } from 'services/api/types'; import { EntityState, createEntityAdapter } from '@reduxjs/toolkit'; import { keyBy } from 'lodash-es'; +import { ModelsList } from 'services/api/types'; import { ApiFullTagDescription, LIST_TAG, api } from '..'; import { paths } from '../schema'; @@ -24,11 +24,9 @@ export const modelsApi = api.injectEndpoints({ listModels: build.query, ListModelsArg>({ query: (arg) => ({ url: 'models/', params: arg }), providesTags: (result, error, arg) => { - // any list of boards const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST_TAG }]; if (result) { - // and individual tags for each board tags.push( ...result.ids.map((id) => ({ type: 'Model' as const, diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index 18942a47d6..6f97dd1dbb 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -35,7 +35,9 @@ export type ModelType = S<'ModelType'>; export type BaseModelType = S<'BaseModelType'>; export type MainModelField = S<'MainModelField'>; export type VAEModelField = S<'VAEModelField'>; +export type LoRAModelField = S<'LoRAModelField'>; export type ModelsList = S<'ModelsList'>; +export type LoRAModelConfig = S<'LoRAModelConfig'>; // Graphs export type Graph = S<'Graph'>; @@ -60,6 +62,7 @@ export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>; export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>; export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>; export type MainModelLoaderInvocation = N<'MainModelLoaderInvocation'>; +export type LoraLoaderInvocation = N<'LoraLoaderInvocation'>; // ControlNet Nodes export type ControlNetInvocation = N<'ControlNetInvocation'>;