feat(ui): add LoRA ui & update graphs

This commit is contained in:
psychedelicious 2023-07-04 21:12:52 +10:00
parent d537b9f0cb
commit db8862d860
26 changed files with 630 additions and 39 deletions

View File

@ -20,10 +20,8 @@ const serializationDenylist: {
nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist,
// config: configPersistDenyList,
ui: uiPersistDenylist,
controlNet: controlNetDenylist,
// hotkeys: hotkeysPersistDenylist,
};
export const serialize: SerializeFunction = (data, key) => {

View File

@ -8,31 +8,32 @@ import {
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import batchReducer from 'features/batch/store/batchSlice';
import canvasReducer from 'features/canvas/store/canvasSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
import boardsReducer from 'features/gallery/store/boardSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import loraReducer from 'features/lora/store/loraSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
import batchReducer from 'features/batch/store/batchSlice';
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
import { listenerMiddleware } from './middleware/listenerMiddleware';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { api } from 'services/api';
import { LOCALSTORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize';
import { api } from 'services/api';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
const allReducers = {
canvas: canvasReducer,
@ -50,6 +51,7 @@ const allReducers = {
dynamicPrompts: dynamicPromptsReducer,
batch: batchReducer,
imageDeletion: imageDeletionReducer,
lora: loraReducer,
[api.reducerPath]: api.reducer,
};
@ -69,6 +71,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'controlNet',
'dynamicPrompts',
'batch',
'lora',
// 'boards',
// 'hotkeys',
// 'config',

View File

@ -0,0 +1,58 @@
import { Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react';
import { FaTrash } from 'react-icons/fa';
import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice';
type Props = {
lora: Lora;
};
const ParamLora = (props: Props) => {
const dispatch = useAppDispatch();
const { lora } = props;
const handleChange = useCallback(
(v: number) => {
dispatch(loraWeightChanged({ name: lora.name, weight: v }));
},
[dispatch, lora.name]
);
const handleReset = useCallback(() => {
dispatch(loraWeightChanged({ name: lora.name, weight: 1 }));
}, [dispatch, lora.name]);
const handleRemoveLora = useCallback(() => {
dispatch(loraRemoved(lora.name));
}, [dispatch, lora.name]);
return (
<Flex sx={{ gap: 2.5, alignItems: 'flex-end' }}>
<IAISlider
label={lora.name}
value={lora.weight}
onChange={handleChange}
min={0}
max={1}
step={0.01}
withInput
withReset
handleReset={handleReset}
withSliderMarks
/>
<IAIIconButton
size="sm"
onClick={handleRemoveLora}
tooltip="Remove LoRA"
aria-label="Remove LoRA"
icon={<FaTrash />}
colorScheme="error"
/>
</Flex>
);
};
export default memo(ParamLora);

View File

@ -0,0 +1,20 @@
import { Flex, useDisclosure } from '@chakra-ui/react';
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import ParamLoraList from './ParamLoraList';
import ParamLoraSelect from './ParamLoraSelect';
const ParamLoraCollapse = () => {
const { isOpen, onToggle } = useDisclosure();
return (
<IAICollapse label="LoRAs" isOpen={isOpen} onToggle={onToggle}>
<Flex sx={{ flexDir: 'column', gap: 2 }}>
<ParamLoraSelect />
<ParamLoraList />
</Flex>
</IAICollapse>
);
};
export default memo(ParamLoraCollapse);

View File

@ -0,0 +1,19 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { map } from 'lodash-es';
import ParamLora from './ParamLora';
const selector = createSelector(stateSelector, ({ lora }) => {
const { loras } = lora;
return { loras };
});
const ParamLoraList = () => {
const { loras } = useAppSelector(selector);
return map(loras, (lora) => <ParamLora key={lora.name} lora={lora} />);
};
export default ParamLoraList;

View File

@ -0,0 +1,103 @@
import { Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { forEach } from 'lodash-es';
import { forwardRef, useCallback, useMemo } from 'react';
import { useListModelsQuery } from 'services/api/endpoints/models';
import { loraAdded } from '../store/loraSlice';
type LoraSelectItem = {
label: string;
value: string;
description?: string;
};
const selector = createSelector(
stateSelector,
({ lora }) => ({
loras: lora.loras,
}),
defaultSelectorOptions
);
const ParamLoraSelect = () => {
const dispatch = useAppDispatch();
const { loras } = useAppSelector(selector);
const { data: lorasQueryData } = useListModelsQuery({ model_type: 'lora' });
const data = useMemo(() => {
if (!lorasQueryData) {
return [];
}
const data: LoraSelectItem[] = [];
forEach(lorasQueryData.entities, (lora, id) => {
if (!lora || Boolean(id in loras)) {
return;
}
data.push({
value: id,
label: lora.name,
description: lora.description,
});
});
return data;
}, [loras, lorasQueryData]);
const handleChange = useCallback(
(v: string[]) => {
v[0] && dispatch(loraAdded(v[0]));
},
[dispatch]
);
return (
<IAIMantineMultiSelect
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}
value={[]}
data={data}
maxDropdownHeight={400}
nothingFound="No matching LoRAs"
itemComponent={SelectItem}
disabled={data.length === 0}
filter={(value, selected, item: LoraSelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={handleChange}
/>
);
};
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description?: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';
export default ParamLoraSelect;

View File

@ -0,0 +1,44 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
export type Lora = {
name: string;
weight: number;
};
export const defaultLoRAConfig: Omit<Lora, 'name'> = {
weight: 1,
};
export type LoraState = {
loras: Record<string, Lora>;
};
export const intialLoraState: LoraState = {
loras: {},
};
export const loraSlice = createSlice({
name: 'lora',
initialState: intialLoraState,
reducers: {
loraAdded: (state, action: PayloadAction<string>) => {
const name = action.payload;
state.loras[name] = { name, ...defaultLoRAConfig };
},
loraRemoved: (state, action: PayloadAction<string>) => {
const name = action.payload;
delete state.loras[name];
},
loraWeightChanged: (
state,
action: PayloadAction<{ name: string; weight: number }>
) => {
const { name, weight } = action.payload;
state.loras[name].weight = weight;
},
},
});
export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions;
export default loraSlice.reducer;

View File

@ -12,6 +12,7 @@ import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFie
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent';
@ -163,6 +164,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
);
}
if (type === 'lora_model' && template.type === 'lora_model') {
return (
<LoRAModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'array' && template.type === 'array') {
return (
<ArrayInputFieldComponent

View File

@ -0,0 +1,104 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import { forEach, isString } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useListModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const LoRAModelInputFieldComponent = (
props: FieldComponentProps<
VaeModelInputFieldValue,
VaeModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: loraModels } = useListModelsQuery({
model_type: 'lora',
});
const selectedModel = useMemo(
() => loraModels?.entities[field.value ?? loraModels.ids[0]],
[loraModels?.entities, loraModels?.ids, field.value]
);
const data = useMemo(() => {
if (!loraModels) {
return [];
}
const data: SelectItem[] = [];
forEach(loraModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: BASE_MODEL_NAME_MAP[model.base_model],
});
});
return data;
}, [loraModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: v,
})
);
},
[dispatch, field.name, nodeId]
);
useEffect(() => {
if (field.value && loraModels?.ids.includes(field.value)) {
return;
}
const firstLora = loraModels?.ids[0];
if (!isString(firstLora)) {
return;
}
handleValueChanged(firstLora);
}, [field.value, handleValueChanged, loraModels?.ids]);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model &&
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
}
value={field.value}
placeholder="Pick one"
data={data}
onChange={handleValueChanged}
/>
);
};
export default memo(LoRAModelInputFieldComponent);

View File

@ -1,5 +1,8 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { cloneDeep, uniqBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful';
import {
addEdge,
applyEdgeChanges,
@ -11,12 +14,9 @@ import {
NodeChange,
OnConnectStartParams,
} from 'reactflow';
import { ImageField } from 'services/api/types';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { ImageField } from 'services/api/types';
import { InvocationTemplate, InvocationValue } from '../types/types';
import { RgbaColor } from 'react-colorful';
import { RootState } from 'app/store/store';
import { cloneDeep, isArray, uniq, uniqBy } from 'lodash-es';
export type NodesState = {
nodes: Node<InvocationValue>[];

View File

@ -18,6 +18,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
VaeField: 'vae',
model: 'model',
vae_model: 'vae_model',
lora_model: 'lora_model',
array: 'array',
item: 'item',
ColorField: 'color',
@ -120,7 +121,13 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
vae_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'Model',
title: 'VAE',
description: 'Models are models.',
},
lora_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'LoRA',
description: 'Models are models.',
},
array: {

View File

@ -65,6 +65,7 @@ export type FieldType =
| 'control'
| 'model'
| 'vae_model'
| 'lora_model'
| 'array'
| 'item'
| 'color'
@ -93,6 +94,7 @@ export type InputFieldValue =
| EnumInputFieldValue
| ModelInputFieldValue
| VaeModelInputFieldValue
| LoRAModelInputFieldValue
| ArrayInputFieldValue
| ItemInputFieldValue
| ColorInputFieldValue
@ -119,6 +121,7 @@ export type InputFieldTemplate =
| EnumInputFieldTemplate
| ModelInputFieldTemplate
| VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate
| ArrayInputFieldTemplate
| ItemInputFieldTemplate
| ColorInputFieldTemplate
@ -236,6 +239,11 @@ export type VaeModelInputFieldValue = FieldValueBase & {
value?: string;
};
export type LoRAModelInputFieldValue = FieldValueBase & {
type: 'lora_model';
value?: string;
};
export type ArrayInputFieldValue = FieldValueBase & {
type: 'array';
value?: (string | number)[];
@ -350,6 +358,11 @@ export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'vae_model';
};
export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'lora_model';
};
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
default: [];
type: 'array';

View File

@ -18,6 +18,7 @@ import {
IntegerInputFieldTemplate,
ItemInputFieldTemplate,
LatentsInputFieldTemplate,
LoRAModelInputFieldTemplate,
ModelInputFieldTemplate,
OutputFieldTemplate,
StringInputFieldTemplate,
@ -191,6 +192,21 @@ const buildVaeModelInputFieldTemplate = ({
return template;
};
const buildLoRAModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): LoRAModelInputFieldTemplate => {
const template: LoRAModelInputFieldTemplate = {
...baseField,
type: 'lora_model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageInputFieldTemplate = ({
schemaObject,
baseField,
@ -460,6 +476,9 @@ export const buildInputFieldTemplate = (
if (['vae_model'].includes(fieldType)) {
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
}
if (['lora_model'].includes(fieldType)) {
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
}
if (['enum'].includes(fieldType)) {
return buildEnumInputFieldTemplate({ schemaObject, baseField });
}

View File

@ -79,6 +79,10 @@ export const buildInputFieldValue = (
if (template.type === 'vae_model') {
fieldValue.value = undefined;
}
if (template.type === 'lora_model') {
fieldValue.value = undefined;
}
}
return fieldValue;

View File

@ -0,0 +1,150 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, size } from 'lodash-es';
import { LoraLoaderInvocation } from 'services/api/types';
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
import {
LORA_LOADER,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
} from './constants';
export const addLoRAsToGraph = (
graph: NonNullableGraph,
state: RootState,
baseNodeId: string
): void => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
* or to the inference/conditioning nodes.
*
* So we need to inject a LoRA chain into the graph.
*/
const { loras } = state.lora;
const loraCount = size(loras);
if (loraCount > 0) {
// remove any existing connections from main model loader, we need to insert the lora nodes
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === MAIN_MODEL_LOADER &&
['unet', 'clip'].includes(e.source.field)
)
);
}
// we need to remember the last lora so we can chain from it
let lastLoraNodeId = '';
let currentLoraIndex = 0;
forEach(loras, (lora) => {
const { name, weight } = lora;
const loraField = modelIdToLoRAModelField(name);
const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace(
'.',
'_'
)}`;
console.log(lastLoraNodeId, currentLoraNodeId, currentLoraIndex, loraField);
const loraLoaderNode: LoraLoaderInvocation = {
type: 'lora_loader',
id: currentLoraNodeId,
lora: loraField,
weight,
};
graph.nodes[currentLoraNodeId] = loraLoaderNode;
if (currentLoraIndex === 0) {
// first lora = start the lora chain, attach directly to model loader
graph.edges.push({
source: {
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: currentLoraNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip',
},
});
} else {
// we are in the middle of the lora chain, instead connect to the previous lora
graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'unet',
},
destination: {
node_id: currentLoraNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'clip',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip',
},
});
}
if (currentLoraIndex === loraCount - 1) {
// final lora, end the lora chain - we need to connect up to inference and conditioning nodes
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'unet',
},
destination: {
node_id: baseNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
});
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
});
}
// increment the lora for the next one in the chain
lastLoraNodeId = currentLoraNodeId;
currentLoraIndex += 1;
});
};

View File

@ -9,6 +9,7 @@ import {
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
IMAGE_TO_IMAGE_GRAPH,
@ -252,6 +253,8 @@ export const buildCanvasImageToImageGraph = (
});
}
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS);
// Add VAE
addVAEToGraph(graph, state);

View File

@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
LATENTS_TO_IMAGE,
@ -157,6 +158,8 @@ export const buildCanvasTextToImageGraph = (
],
};
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS);
// Add VAE
addVAEToGraph(graph, state);

View File

@ -10,6 +10,7 @@ import {
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
IMAGE_COLLECTION,
@ -304,6 +305,9 @@ export const buildLinearImageToImageGraph = (
},
});
}
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS);
// Add VAE
addVAEToGraph(graph, state);

View File

@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
LATENTS_TO_IMAGE,
@ -150,6 +151,8 @@ export const buildLinearTextToImageGraph = (
],
};
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS);
// Add Custom VAE Support
addVAEToGraph(graph, state);

View File

@ -4,6 +4,7 @@ import { cloneDeep, omit, reduce } from 'lodash-es';
import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid';
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
@ -38,6 +39,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
}
}
if (field.type === 'lora_model') {
if (field.value) {
return modelIdToLoRAModelField(field.value);
}
}
return field.value;
};

View File

@ -9,6 +9,7 @@ export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate';
export const MAIN_MODEL_LOADER = 'main_model_loader';
export const VAE_LOADER = 'vae_loader';
export const LORA_LOADER = 'lora_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image';

View File

@ -0,0 +1,12 @@
import { BaseModelType, LoRAModelField } from 'services/api/types';
export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => {
const [base_model, model_type, model_name] = loraId.split('/');
const field: LoRAModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -1,14 +1,15 @@
import { memo } from 'react';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import { memo } from 'react';
import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters';
const ImageToImageTabParameters = () => {
return (
@ -17,6 +18,7 @@ const ImageToImageTabParameters = () => {
<ParamNegativeConditioning />
<ProcessButtons />
<ImageToImageTabCoreParameters />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse />

View File

@ -1,15 +1,16 @@
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import { memo } from 'react';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import TextToImageTabCoreParameters from './TextToImageTabCoreParameters';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
const TextToImageTabParameters = () => {
return (
@ -18,6 +19,7 @@ const TextToImageTabParameters = () => {
<ParamNegativeConditioning />
<ProcessButtons />
<TextToImageTabCoreParameters />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse />

View File

@ -1,6 +1,6 @@
import { ModelsList } from 'services/api/types';
import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
import { keyBy } from 'lodash-es';
import { ModelsList } from 'services/api/types';
import { ApiFullTagDescription, LIST_TAG, api } from '..';
import { paths } from '../schema';
@ -24,11 +24,9 @@ export const modelsApi = api.injectEndpoints({
listModels: build.query<EntityState<ModelConfig>, ListModelsArg>({
query: (arg) => ({ url: 'models/', params: arg }),
providesTags: (result, error, arg) => {
// any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST_TAG }];
if (result) {
// and individual tags for each board
tags.push(
...result.ids.map((id) => ({
type: 'Model' as const,

View File

@ -35,7 +35,9 @@ export type ModelType = S<'ModelType'>;
export type BaseModelType = S<'BaseModelType'>;
export type MainModelField = S<'MainModelField'>;
export type VAEModelField = S<'VAEModelField'>;
export type LoRAModelField = S<'LoRAModelField'>;
export type ModelsList = S<'ModelsList'>;
export type LoRAModelConfig = S<'LoRAModelConfig'>;
// Graphs
export type Graph = S<'Graph'>;
@ -60,6 +62,7 @@ export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>;
export type MainModelLoaderInvocation = N<'MainModelLoaderInvocation'>;
export type LoraLoaderInvocation = N<'LoraLoaderInvocation'>;
// ControlNet Nodes
export type ControlNetInvocation = N<'ControlNetInvocation'>;