mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add LoRA ui & update graphs
This commit is contained in:
parent
d537b9f0cb
commit
db8862d860
@ -20,10 +20,8 @@ const serializationDenylist: {
|
||||
nodes: nodesPersistDenylist,
|
||||
postprocessing: postprocessingPersistDenylist,
|
||||
system: systemPersistDenylist,
|
||||
// config: configPersistDenyList,
|
||||
ui: uiPersistDenylist,
|
||||
controlNet: controlNetDenylist,
|
||||
// hotkeys: hotkeysPersistDenylist,
|
||||
};
|
||||
|
||||
export const serialize: SerializeFunction = (data, key) => {
|
||||
|
@ -8,31 +8,32 @@ import {
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
||||
|
||||
import batchReducer from 'features/batch/store/batchSlice';
|
||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
||||
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
|
||||
import boardsReducer from 'features/gallery/store/boardSlice';
|
||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
|
||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||
import loraReducer from 'features/lora/store/loraSlice';
|
||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||
import generationReducer from 'features/parameters/store/generationSlice';
|
||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||
import systemReducer from 'features/system/store/systemSlice';
|
||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||
import boardsReducer from 'features/gallery/store/boardSlice';
|
||||
import configReducer from 'features/system/store/configSlice';
|
||||
import systemReducer from 'features/system/store/systemSlice';
|
||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||
import uiReducer from 'features/ui/store/uiSlice';
|
||||
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
|
||||
import batchReducer from 'features/batch/store/batchSlice';
|
||||
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
|
||||
|
||||
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||
|
||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||
import { api } from 'services/api';
|
||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||
import { serialize } from './enhancers/reduxRemember/serialize';
|
||||
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
||||
import { api } from 'services/api';
|
||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||
|
||||
const allReducers = {
|
||||
canvas: canvasReducer,
|
||||
@ -50,6 +51,7 @@ const allReducers = {
|
||||
dynamicPrompts: dynamicPromptsReducer,
|
||||
batch: batchReducer,
|
||||
imageDeletion: imageDeletionReducer,
|
||||
lora: loraReducer,
|
||||
[api.reducerPath]: api.reducer,
|
||||
};
|
||||
|
||||
@ -69,6 +71,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
||||
'controlNet',
|
||||
'dynamicPrompts',
|
||||
'batch',
|
||||
'lora',
|
||||
// 'boards',
|
||||
// 'hotkeys',
|
||||
// 'config',
|
||||
|
@ -0,0 +1,58 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice';
|
||||
|
||||
type Props = {
|
||||
lora: Lora;
|
||||
};
|
||||
|
||||
const ParamLora = (props: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { lora } = props;
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(loraWeightChanged({ name: lora.name, weight: v }));
|
||||
},
|
||||
[dispatch, lora.name]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
dispatch(loraWeightChanged({ name: lora.name, weight: 1 }));
|
||||
}, [dispatch, lora.name]);
|
||||
|
||||
const handleRemoveLora = useCallback(() => {
|
||||
dispatch(loraRemoved(lora.name));
|
||||
}, [dispatch, lora.name]);
|
||||
|
||||
return (
|
||||
<Flex sx={{ gap: 2.5, alignItems: 'flex-end' }}>
|
||||
<IAISlider
|
||||
label={lora.name}
|
||||
value={lora.weight}
|
||||
onChange={handleChange}
|
||||
min={0}
|
||||
max={1}
|
||||
step={0.01}
|
||||
withInput
|
||||
withReset
|
||||
handleReset={handleReset}
|
||||
withSliderMarks
|
||||
/>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
onClick={handleRemoveLora}
|
||||
tooltip="Remove LoRA"
|
||||
aria-label="Remove LoRA"
|
||||
icon={<FaTrash />}
|
||||
colorScheme="error"
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamLora);
|
@ -0,0 +1,20 @@
|
||||
import { Flex, useDisclosure } from '@chakra-ui/react';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import { memo } from 'react';
|
||||
import ParamLoraList from './ParamLoraList';
|
||||
import ParamLoraSelect from './ParamLoraSelect';
|
||||
|
||||
const ParamLoraCollapse = () => {
|
||||
const { isOpen, onToggle } = useDisclosure();
|
||||
|
||||
return (
|
||||
<IAICollapse label="LoRAs" isOpen={isOpen} onToggle={onToggle}>
|
||||
<Flex sx={{ flexDir: 'column', gap: 2 }}>
|
||||
<ParamLoraSelect />
|
||||
<ParamLoraList />
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamLoraCollapse);
|
@ -0,0 +1,19 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { map } from 'lodash-es';
|
||||
import ParamLora from './ParamLora';
|
||||
|
||||
const selector = createSelector(stateSelector, ({ lora }) => {
|
||||
const { loras } = lora;
|
||||
|
||||
return { loras };
|
||||
});
|
||||
|
||||
const ParamLoraList = () => {
|
||||
const { loras } = useAppSelector(selector);
|
||||
|
||||
return map(loras, (lora) => <ParamLora key={lora.name} lora={lora} />);
|
||||
};
|
||||
|
||||
export default ParamLoraList;
|
@ -0,0 +1,103 @@
|
||||
import { Text } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { forwardRef, useCallback, useMemo } from 'react';
|
||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||
import { loraAdded } from '../store/loraSlice';
|
||||
|
||||
type LoraSelectItem = {
|
||||
label: string;
|
||||
value: string;
|
||||
description?: string;
|
||||
};
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ lora }) => ({
|
||||
loras: lora.loras,
|
||||
}),
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamLoraSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { loras } = useAppSelector(selector);
|
||||
const { data: lorasQueryData } = useListModelsQuery({ model_type: 'lora' });
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!lorasQueryData) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: LoraSelectItem[] = [];
|
||||
|
||||
forEach(lorasQueryData.entities, (lora, id) => {
|
||||
if (!lora || Boolean(id in loras)) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: lora.name,
|
||||
description: lora.description,
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [loras, lorasQueryData]);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string[]) => {
|
||||
v[0] && dispatch(loraAdded(v[0]));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineMultiSelect
|
||||
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}
|
||||
value={[]}
|
||||
data={data}
|
||||
maxDropdownHeight={400}
|
||||
nothingFound="No matching LoRAs"
|
||||
itemComponent={SelectItem}
|
||||
disabled={data.length === 0}
|
||||
filter={(value, selected, item: LoraSelectItem) =>
|
||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||
}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||
value: string;
|
||||
label: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
|
||||
({ label, description, ...others }: ItemProps, ref) => {
|
||||
return (
|
||||
<div ref={ref} {...others}>
|
||||
<div>
|
||||
<Text>{label}</Text>
|
||||
{description && (
|
||||
<Text size="xs" color="base.600">
|
||||
{description}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
SelectItem.displayName = 'SelectItem';
|
||||
|
||||
export default ParamLoraSelect;
|
44
invokeai/frontend/web/src/features/lora/store/loraSlice.ts
Normal file
44
invokeai/frontend/web/src/features/lora/store/loraSlice.ts
Normal file
@ -0,0 +1,44 @@
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
|
||||
export type Lora = {
|
||||
name: string;
|
||||
weight: number;
|
||||
};
|
||||
|
||||
export const defaultLoRAConfig: Omit<Lora, 'name'> = {
|
||||
weight: 1,
|
||||
};
|
||||
|
||||
export type LoraState = {
|
||||
loras: Record<string, Lora>;
|
||||
};
|
||||
|
||||
export const intialLoraState: LoraState = {
|
||||
loras: {},
|
||||
};
|
||||
|
||||
export const loraSlice = createSlice({
|
||||
name: 'lora',
|
||||
initialState: intialLoraState,
|
||||
reducers: {
|
||||
loraAdded: (state, action: PayloadAction<string>) => {
|
||||
const name = action.payload;
|
||||
state.loras[name] = { name, ...defaultLoRAConfig };
|
||||
},
|
||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||
const name = action.payload;
|
||||
delete state.loras[name];
|
||||
},
|
||||
loraWeightChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ name: string; weight: number }>
|
||||
) => {
|
||||
const { name, weight } = action.payload;
|
||||
state.loras[name].weight = weight;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions;
|
||||
|
||||
export default loraSlice.reducer;
|
@ -12,6 +12,7 @@ import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFie
|
||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
|
||||
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
||||
import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent';
|
||||
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
||||
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
||||
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
||||
@ -163,6 +164,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'lora_model' && template.type === 'lora_model') {
|
||||
return (
|
||||
<LoRAModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'array' && template.type === 'array') {
|
||||
return (
|
||||
<ArrayInputFieldComponent
|
||||
|
@ -0,0 +1,104 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
VaeModelInputFieldTemplate,
|
||||
VaeModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const LoRAModelInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
VaeModelInputFieldValue,
|
||||
VaeModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: loraModels } = useListModelsQuery({
|
||||
model_type: 'lora',
|
||||
});
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => loraModels?.entities[field.value ?? loraModels.ids[0]],
|
||||
[loraModels?.entities, loraModels?.ids, field.value]
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!loraModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(loraModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.name,
|
||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [loraModels]);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: v,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (field.value && loraModels?.ids.includes(field.value)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstLora = loraModels?.ids[0];
|
||||
|
||||
if (!isString(firstLora)) {
|
||||
return;
|
||||
}
|
||||
|
||||
handleValueChanged(firstLora);
|
||||
}, [field.value, handleValueChanged, loraModels?.ids]);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model &&
|
||||
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={field.value}
|
||||
placeholder="Pick one"
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LoRAModelInputFieldComponent);
|
@ -1,5 +1,8 @@
|
||||
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { cloneDeep, uniqBy } from 'lodash-es';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
import {
|
||||
addEdge,
|
||||
applyEdgeChanges,
|
||||
@ -11,12 +14,9 @@ import {
|
||||
NodeChange,
|
||||
OnConnectStartParams,
|
||||
} from 'reactflow';
|
||||
import { ImageField } from 'services/api/types';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { ImageField } from 'services/api/types';
|
||||
import { InvocationTemplate, InvocationValue } from '../types/types';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { cloneDeep, isArray, uniq, uniqBy } from 'lodash-es';
|
||||
|
||||
export type NodesState = {
|
||||
nodes: Node<InvocationValue>[];
|
||||
|
@ -18,6 +18,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
||||
VaeField: 'vae',
|
||||
model: 'model',
|
||||
vae_model: 'vae_model',
|
||||
lora_model: 'lora_model',
|
||||
array: 'array',
|
||||
item: 'item',
|
||||
ColorField: 'color',
|
||||
@ -120,7 +121,13 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
vae_model: {
|
||||
color: 'teal',
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
title: 'Model',
|
||||
title: 'VAE',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
lora_model: {
|
||||
color: 'teal',
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
title: 'LoRA',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
array: {
|
||||
|
@ -65,6 +65,7 @@ export type FieldType =
|
||||
| 'control'
|
||||
| 'model'
|
||||
| 'vae_model'
|
||||
| 'lora_model'
|
||||
| 'array'
|
||||
| 'item'
|
||||
| 'color'
|
||||
@ -93,6 +94,7 @@ export type InputFieldValue =
|
||||
| EnumInputFieldValue
|
||||
| ModelInputFieldValue
|
||||
| VaeModelInputFieldValue
|
||||
| LoRAModelInputFieldValue
|
||||
| ArrayInputFieldValue
|
||||
| ItemInputFieldValue
|
||||
| ColorInputFieldValue
|
||||
@ -119,6 +121,7 @@ export type InputFieldTemplate =
|
||||
| EnumInputFieldTemplate
|
||||
| ModelInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate
|
||||
| LoRAModelInputFieldTemplate
|
||||
| ArrayInputFieldTemplate
|
||||
| ItemInputFieldTemplate
|
||||
| ColorInputFieldTemplate
|
||||
@ -236,6 +239,11 @@ export type VaeModelInputFieldValue = FieldValueBase & {
|
||||
value?: string;
|
||||
};
|
||||
|
||||
export type LoRAModelInputFieldValue = FieldValueBase & {
|
||||
type: 'lora_model';
|
||||
value?: string;
|
||||
};
|
||||
|
||||
export type ArrayInputFieldValue = FieldValueBase & {
|
||||
type: 'array';
|
||||
value?: (string | number)[];
|
||||
@ -350,6 +358,11 @@ export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'vae_model';
|
||||
};
|
||||
|
||||
export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'lora_model';
|
||||
};
|
||||
|
||||
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: [];
|
||||
type: 'array';
|
||||
|
@ -18,6 +18,7 @@ import {
|
||||
IntegerInputFieldTemplate,
|
||||
ItemInputFieldTemplate,
|
||||
LatentsInputFieldTemplate,
|
||||
LoRAModelInputFieldTemplate,
|
||||
ModelInputFieldTemplate,
|
||||
OutputFieldTemplate,
|
||||
StringInputFieldTemplate,
|
||||
@ -191,6 +192,21 @@ const buildVaeModelInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLoRAModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): LoRAModelInputFieldTemplate => {
|
||||
const template: LoRAModelInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'lora_model',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'direct',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImageInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -460,6 +476,9 @@ export const buildInputFieldTemplate = (
|
||||
if (['vae_model'].includes(fieldType)) {
|
||||
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['lora_model'].includes(fieldType)) {
|
||||
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['enum'].includes(fieldType)) {
|
||||
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
|
@ -79,6 +79,10 @@ export const buildInputFieldValue = (
|
||||
if (template.type === 'vae_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'lora_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
return fieldValue;
|
||||
|
@ -0,0 +1,150 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { forEach, size } from 'lodash-es';
|
||||
import { LoraLoaderInvocation } from 'services/api/types';
|
||||
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
|
||||
import {
|
||||
LORA_LOADER,
|
||||
MAIN_MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING,
|
||||
} from './constants';
|
||||
|
||||
export const addLoRAsToGraph = (
|
||||
graph: NonNullableGraph,
|
||||
state: RootState,
|
||||
baseNodeId: string
|
||||
): void => {
|
||||
/**
|
||||
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
||||
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
|
||||
* or to the inference/conditioning nodes.
|
||||
*
|
||||
* So we need to inject a LoRA chain into the graph.
|
||||
*/
|
||||
|
||||
const { loras } = state.lora;
|
||||
const loraCount = size(loras);
|
||||
|
||||
if (loraCount > 0) {
|
||||
// remove any existing connections from main model loader, we need to insert the lora nodes
|
||||
graph.edges = graph.edges.filter(
|
||||
(e) =>
|
||||
!(
|
||||
e.source.node_id === MAIN_MODEL_LOADER &&
|
||||
['unet', 'clip'].includes(e.source.field)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// we need to remember the last lora so we can chain from it
|
||||
let lastLoraNodeId = '';
|
||||
let currentLoraIndex = 0;
|
||||
|
||||
forEach(loras, (lora) => {
|
||||
const { name, weight } = lora;
|
||||
const loraField = modelIdToLoRAModelField(name);
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace(
|
||||
'.',
|
||||
'_'
|
||||
)}`;
|
||||
|
||||
console.log(lastLoraNodeId, currentLoraNodeId, currentLoraIndex, loraField);
|
||||
|
||||
const loraLoaderNode: LoraLoaderInvocation = {
|
||||
type: 'lora_loader',
|
||||
id: currentLoraNodeId,
|
||||
lora: loraField,
|
||||
weight,
|
||||
};
|
||||
|
||||
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
||||
|
||||
if (currentLoraIndex === 0) {
|
||||
// first lora = start the lora chain, attach directly to model loader
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: currentLoraNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: currentLoraNodeId,
|
||||
field: 'clip',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// we are in the middle of the lora chain, instead connect to the previous lora
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: lastLoraNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: currentLoraNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
});
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: lastLoraNodeId,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: currentLoraNodeId,
|
||||
field: 'clip',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
if (currentLoraIndex === loraCount - 1) {
|
||||
// final lora, end the lora chain - we need to connect up to inference and conditioning nodes
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: currentLoraNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: baseNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: currentLoraNodeId,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: currentLoraNodeId,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// increment the lora for the next one in the chain
|
||||
lastLoraNodeId = currentLoraNodeId;
|
||||
currentLoraIndex += 1;
|
||||
});
|
||||
};
|
@ -9,6 +9,7 @@ import {
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import {
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
@ -252,6 +253,8 @@ export const buildCanvasImageToImageGraph = (
|
||||
});
|
||||
}
|
||||
|
||||
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS);
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(graph, state);
|
||||
|
||||
|
@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import {
|
||||
LATENTS_TO_IMAGE,
|
||||
@ -157,6 +158,8 @@ export const buildCanvasTextToImageGraph = (
|
||||
],
|
||||
};
|
||||
|
||||
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS);
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(graph, state);
|
||||
|
||||
|
@ -10,6 +10,7 @@ import {
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import {
|
||||
IMAGE_COLLECTION,
|
||||
@ -304,6 +305,9 @@ export const buildLinearImageToImageGraph = (
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS);
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(graph, state);
|
||||
|
||||
|
@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import {
|
||||
LATENTS_TO_IMAGE,
|
||||
@ -150,6 +151,8 @@ export const buildLinearTextToImageGraph = (
|
||||
],
|
||||
};
|
||||
|
||||
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS);
|
||||
|
||||
// Add Custom VAE Support
|
||||
addVAEToGraph(graph, state);
|
||||
|
||||
|
@ -4,6 +4,7 @@ import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||
import { Graph } from 'services/api/types';
|
||||
import { AnyInvocation } from 'services/events/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
|
||||
|
||||
@ -38,6 +39,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
|
||||
}
|
||||
}
|
||||
|
||||
if (field.type === 'lora_model') {
|
||||
if (field.value) {
|
||||
return modelIdToLoRAModelField(field.value);
|
||||
}
|
||||
}
|
||||
|
||||
return field.value;
|
||||
};
|
||||
|
||||
|
@ -9,6 +9,7 @@ export const RANGE_OF_SIZE = 'range_of_size';
|
||||
export const ITERATE = 'iterate';
|
||||
export const MAIN_MODEL_LOADER = 'main_model_loader';
|
||||
export const VAE_LOADER = 'vae_loader';
|
||||
export const LORA_LOADER = 'lora_loader';
|
||||
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
||||
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
||||
export const RESIZE = 'resize_image';
|
||||
|
@ -0,0 +1,12 @@
|
||||
import { BaseModelType, LoRAModelField } from 'services/api/types';
|
||||
|
||||
export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => {
|
||||
const [base_model, model_type, model_name] = loraId.split('/');
|
||||
|
||||
const field: LoRAModelField = {
|
||||
base_model: base_model as BaseModelType,
|
||||
model_name,
|
||||
};
|
||||
|
||||
return field;
|
||||
};
|
@ -1,14 +1,15 @@
|
||||
import { memo } from 'react';
|
||||
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
|
||||
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
|
||||
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
|
||||
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
|
||||
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
|
||||
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
|
||||
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
|
||||
import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters';
|
||||
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
|
||||
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
|
||||
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
|
||||
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
|
||||
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
|
||||
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
|
||||
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
|
||||
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
|
||||
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
|
||||
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
|
||||
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
|
||||
import { memo } from 'react';
|
||||
import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters';
|
||||
|
||||
const ImageToImageTabParameters = () => {
|
||||
return (
|
||||
@ -17,6 +18,7 @@ const ImageToImageTabParameters = () => {
|
||||
<ParamNegativeConditioning />
|
||||
<ProcessButtons />
|
||||
<ImageToImageTabCoreParameters />
|
||||
<ParamLoraCollapse />
|
||||
<ParamDynamicPromptsCollapse />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamVariationCollapse />
|
||||
|
@ -1,15 +1,16 @@
|
||||
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
|
||||
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
|
||||
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
|
||||
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
|
||||
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
|
||||
import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse';
|
||||
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
|
||||
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
|
||||
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
|
||||
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
|
||||
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
|
||||
import { memo } from 'react';
|
||||
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
|
||||
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
|
||||
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
|
||||
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
|
||||
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
|
||||
import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse';
|
||||
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
|
||||
import TextToImageTabCoreParameters from './TextToImageTabCoreParameters';
|
||||
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
|
||||
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
|
||||
|
||||
const TextToImageTabParameters = () => {
|
||||
return (
|
||||
@ -18,6 +19,7 @@ const TextToImageTabParameters = () => {
|
||||
<ParamNegativeConditioning />
|
||||
<ProcessButtons />
|
||||
<TextToImageTabCoreParameters />
|
||||
<ParamLoraCollapse />
|
||||
<ParamDynamicPromptsCollapse />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamVariationCollapse />
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { ModelsList } from 'services/api/types';
|
||||
import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { keyBy } from 'lodash-es';
|
||||
import { ModelsList } from 'services/api/types';
|
||||
|
||||
import { ApiFullTagDescription, LIST_TAG, api } from '..';
|
||||
import { paths } from '../schema';
|
||||
@ -24,11 +24,9 @@ export const modelsApi = api.injectEndpoints({
|
||||
listModels: build.query<EntityState<ModelConfig>, ListModelsArg>({
|
||||
query: (arg) => ({ url: 'models/', params: arg }),
|
||||
providesTags: (result, error, arg) => {
|
||||
// any list of boards
|
||||
const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST_TAG }];
|
||||
|
||||
if (result) {
|
||||
// and individual tags for each board
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'Model' as const,
|
||||
|
@ -35,7 +35,9 @@ export type ModelType = S<'ModelType'>;
|
||||
export type BaseModelType = S<'BaseModelType'>;
|
||||
export type MainModelField = S<'MainModelField'>;
|
||||
export type VAEModelField = S<'VAEModelField'>;
|
||||
export type LoRAModelField = S<'LoRAModelField'>;
|
||||
export type ModelsList = S<'ModelsList'>;
|
||||
export type LoRAModelConfig = S<'LoRAModelConfig'>;
|
||||
|
||||
// Graphs
|
||||
export type Graph = S<'Graph'>;
|
||||
@ -60,6 +62,7 @@ export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
|
||||
export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
|
||||
export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>;
|
||||
export type MainModelLoaderInvocation = N<'MainModelLoaderInvocation'>;
|
||||
export type LoraLoaderInvocation = N<'LoraLoaderInvocation'>;
|
||||
|
||||
// ControlNet Nodes
|
||||
export type ControlNetInvocation = N<'ControlNetInvocation'>;
|
||||
|
Loading…
Reference in New Issue
Block a user