feat: Add Embedding Picker to Linear UI (#3654)

This commit is contained in:
blessedcoolant 2023-07-07 00:29:19 +12:00 committed by GitHub
commit 405054d802
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 429 additions and 75 deletions

View File

@ -3,15 +3,13 @@ from __future__ import annotations
import copy
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union, List
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from torch.utils.hooks import RemovableHandle
from transformers import CLIPTextModel
from transformers import CLIPTextModel, CLIPTokenizer
class LoRALayerBase:
#rank: Optional[int]
@ -123,8 +121,8 @@ class LoRALayer(LoRALayerBase):
def get_weight(self):
if self.mid is not None:
up = self.up.reshape(up.shape[0], up.shape[1])
down = self.down.reshape(up.shape[0], up.shape[1])
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
@ -410,7 +408,7 @@ class LoRAModel: #(torch.nn.Module):
else:
# TODO: diff/ia3/... format
print(
f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
)
return

View File

@ -1,15 +1,16 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { memo } from 'react';
import { RefObject, memo } from 'react';
import { mode } from 'theme/util/mode';
type IAIMultiSelectProps = MultiSelectProps & {
tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
};
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, ...rest } = props;
const { searchable = true, tooltip, inputRef, ...rest } = props;
const {
base50,
base100,
@ -33,6 +34,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<MultiSelect
ref={inputRef}
searchable={searchable}
styles={() => ({
label: {

View File

@ -0,0 +1,33 @@
import IAIIconButton from 'common/components/IAIIconButton';
import { memo } from 'react';
import { BiCode } from 'react-icons/bi';
type Props = {
onClick: () => void;
};
const AddEmbeddingButton = (props: Props) => {
const { onClick } = props;
return (
<IAIIconButton
size="sm"
aria-label="Add Embedding"
tooltip="Add Embedding"
icon={<BiCode />}
sx={{
p: 2,
color: 'base.700',
_hover: {
color: 'base.550',
},
_active: {
color: 'base.500',
},
}}
variant="link"
onClick={onClick}
/>
);
};
export default memo(AddEmbeddingButton);

View File

@ -0,0 +1,151 @@
import {
Flex,
Popover,
PopoverBody,
PopoverContent,
PopoverTrigger,
Text,
} from '@chakra-ui/react';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { forEach } from 'lodash-es';
import {
PropsWithChildren,
forwardRef,
useCallback,
useMemo,
useRef,
} from 'react';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
type EmbeddingSelectItem = {
label: string;
value: string;
description?: string;
};
type Props = PropsWithChildren & {
onSelect: (v: string) => void;
isOpen: boolean;
onClose: () => void;
};
const ParamEmbeddingPopover = (props: Props) => {
const { onSelect, isOpen, onClose, children } = props;
const { data: embeddingQueryData } = useGetTextualInversionModelsQuery();
const inputRef = useRef<HTMLInputElement>(null);
const data = useMemo(() => {
if (!embeddingQueryData) {
return [];
}
const data: EmbeddingSelectItem[] = [];
forEach(embeddingQueryData.entities, (embedding, _) => {
if (!embedding) return;
data.push({
value: embedding.name,
label: embedding.name,
description: embedding.description,
});
});
return data;
}, [embeddingQueryData]);
const handleChange = useCallback(
(v: string[]) => {
if (v.length === 0) {
return;
}
onSelect(v[0]);
},
[onSelect]
);
return (
<Popover
initialFocusRef={inputRef}
isOpen={isOpen}
onClose={onClose}
placement="bottom"
openDelay={0}
closeDelay={0}
closeOnBlur={true}
returnFocusOnClose={true}
>
<PopoverTrigger>{children}</PopoverTrigger>
<PopoverContent
sx={{
p: 0,
top: -1,
shadow: 'dark-lg',
borderColor: 'accent.300',
borderWidth: '2px',
borderStyle: 'solid',
_dark: { borderColor: 'accent.400' },
}}
>
<PopoverBody
sx={{ p: 0, w: `calc(${PARAMETERS_PANEL_WIDTH} - 2rem )` }}
>
{data.length === 0 ? (
<Flex sx={{ justifyContent: 'center', p: 2 }}>
<Text
sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}
>
No Embeddings Loaded
</Text>
</Flex>
) : (
<IAIMantineMultiSelect
inputRef={inputRef}
placeholder={'Add Embedding'}
value={[]}
data={data}
maxDropdownHeight={400}
nothingFound="No Matching Embeddings"
itemComponent={SelectItem}
disabled={data.length === 0}
filter={(value, selected, item: EmbeddingSelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={handleChange}
/>
)}
</PopoverBody>
</PopoverContent>
</Popover>
);
};
export default ParamEmbeddingPopover;
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description?: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';

View File

@ -4,7 +4,12 @@ import IAIIconButton from 'common/components/IAIIconButton';
import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react';
import { FaTrash } from 'react-icons/fa';
import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice';
import {
Lora,
loraRemoved,
loraWeightChanged,
loraWeightReset,
} from '../store/loraSlice';
type Props = {
lora: Lora;
@ -22,7 +27,7 @@ const ParamLora = (props: Props) => {
);
const handleReset = useCallback(() => {
dispatch(loraWeightChanged({ id: lora.id, weight: 1 }));
dispatch(loraWeightReset(lora.id));
}, [dispatch, lora.id]);
const handleRemoveLora = useCallback(() => {

View File

@ -1,4 +1,4 @@
import { Text } from '@chakra-ui/react';
import { Flex, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@ -61,6 +61,16 @@ const ParamLoraSelect = () => {
[dispatch, lorasQueryData?.entities]
);
if (lorasQueryData?.ids.length === 0) {
return (
<Flex sx={{ justifyContent: 'center', p: 2 }}>
<Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
No LoRAs Loaded
</Text>
</Flex>
);
}
return (
<IAIMantineMultiSelect
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}

View File

@ -8,7 +8,7 @@ export type Lora = {
};
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
weight: 1,
weight: 0.75,
};
export type LoraState = {
@ -38,9 +38,14 @@ export const loraSlice = createSlice({
const { id, weight } = action.payload;
state.loras[id].weight = weight;
},
loraWeightReset: (state, action: PayloadAction<string>) => {
const id = action.payload;
state.loras[id].weight = defaultLoRAConfig.weight;
},
},
});
export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions;
export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } =
loraSlice.actions;
export default loraSlice.reducer;

View File

@ -1,29 +1,107 @@
import { FormControl } from '@chakra-ui/react';
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
import { setNegativePrompt } from 'features/parameters/store/generationSlice';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
import { flushSync } from 'react-dom';
import { useTranslation } from 'react-i18next';
const ParamNegativeConditioning = () => {
const negativePrompt = useAppSelector(
(state: RootState) => state.generation.negativePrompt
);
const promptRef = useRef<HTMLTextAreaElement>(null);
const { isOpen, onClose, onOpen } = useDisclosure();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChangePrompt = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setNegativePrompt(e.target.value));
},
[dispatch]
);
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === '<') {
onOpen();
}
},
[onOpen]
);
const handleSelectEmbedding = useCallback(
(v: string) => {
if (!promptRef.current) {
return;
}
// this is where we insert the TI trigger
const caret = promptRef.current.selectionStart;
if (caret === undefined) {
return;
}
let newPrompt = negativePrompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;
newPrompt += negativePrompt.slice(caret);
// must flush dom updates else selection gets reset
flushSync(() => {
dispatch(setNegativePrompt(newPrompt));
});
// set the caret position to just after the TI trigger promptRef.current.selectionStart = finalCaretPos;
promptRef.current.selectionEnd = finalCaretPos;
onClose();
},
[dispatch, onClose, negativePrompt]
);
return (
<FormControl>
<IAITextarea
id="negativePrompt"
name="negativePrompt"
value={negativePrompt}
onChange={(e) => dispatch(setNegativePrompt(e.target.value))}
placeholder={t('parameters.negativePromptPlaceholder')}
fontSize="sm"
minH={16}
/>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea
id="negativePrompt"
name="negativePrompt"
ref={promptRef}
value={negativePrompt}
placeholder={t('parameters.negativePromptPlaceholder')}
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
fontSize="sm"
minH={16}
/>
</ParamEmbeddingPopover>
{!isOpen && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</FormControl>
);
};

View File

@ -1,4 +1,4 @@
import { Box, FormControl } from '@chakra-ui/react';
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
@ -11,12 +11,15 @@ import {
} from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { userInvoked } from 'app/store/actions';
import IAITextarea from 'common/components/IAITextarea';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
import { isEqual } from 'lodash-es';
import { flushSync } from 'react-dom';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
const promptInputSelector = createSelector(
[(state: RootState) => state.generation, activeTabNameSelector],
@ -40,14 +43,15 @@ const ParamPositiveConditioning = () => {
const dispatch = useAppDispatch();
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
const isReady = useIsReadyToInvoke();
const promptRef = useRef<HTMLTextAreaElement>(null);
const { isOpen, onClose, onOpen } = useDisclosure();
const { t } = useTranslation();
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setPositivePrompt(e.target.value));
};
const handleChangePrompt = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setPositivePrompt(e.target.value));
},
[dispatch]
);
useHotkeys(
'alt+a',
@ -57,6 +61,45 @@ const ParamPositiveConditioning = () => {
[]
);
const handleSelectEmbedding = useCallback(
(v: string) => {
if (!promptRef.current) {
return;
}
// this is where we insert the TI trigger
const caret = promptRef.current.selectionStart;
if (caret === undefined) {
return;
}
let newPrompt = prompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;
newPrompt += prompt.slice(caret);
// must flush dom updates else selection gets reset
flushSync(() => {
dispatch(setPositivePrompt(newPrompt));
});
// set the caret position to just after the TI trigger
promptRef.current.selectionStart = finalCaretPos;
promptRef.current.selectionEnd = finalCaretPos;
onClose();
},
[dispatch, onClose, prompt]
);
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
@ -64,25 +107,50 @@ const ParamPositiveConditioning = () => {
dispatch(clampSymmetrySteps());
dispatch(userInvoked(activeTabName));
}
if (e.key === '<') {
onOpen();
}
},
[dispatch, activeTabName, isReady]
[isReady, dispatch, activeTabName, onOpen]
);
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
// const target = e.target as HTMLTextAreaElement;
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
// };
return (
<Box>
<FormControl>
<IAITextarea
id="prompt"
name="prompt"
placeholder={t('parameters.positivePromptPlaceholder')}
value={prompt}
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
ref={promptRef}
minH={32}
/>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea
id="prompt"
name="prompt"
ref={promptRef}
value={prompt}
placeholder={t('parameters.positivePromptPlaceholder')}
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
minH={32}
/>
</ParamEmbeddingPopover>
</FormControl>
{!isOpen && (
<Box
sx={{
position: 'absolute',
top: 6,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</Box>
);
};

View File

@ -1,4 +1,3 @@
import { Tooltip } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton, {
IAIIconButtonProps,
@ -25,26 +24,25 @@ const PinParametersPanelButton = (props: PinParametersPanelButtonProps) => {
};
return (
<Tooltip label={t('common.pinOptionsPanel')}>
<IAIIconButton
{...props}
aria-label={t('common.pinOptionsPanel')}
onClick={handleClickPinOptionsPanel}
icon={shouldPinParametersPanel ? <BsPinAngleFill /> : <BsPinAngle />}
variant="ghost"
size="sm"
sx={{
color: 'base.700',
_hover: {
color: 'base.550',
},
_active: {
color: 'base.500',
},
...sx,
}}
/>
</Tooltip>
<IAIIconButton
{...props}
tooltip={t('common.pinOptionsPanel')}
aria-label={t('common.pinOptionsPanel')}
onClick={handleClickPinOptionsPanel}
icon={shouldPinParametersPanel ? <BsPinAngleFill /> : <BsPinAngle />}
variant="ghost"
size="sm"
sx={{
color: 'base.700',
_hover: {
color: 'base.550',
},
_active: {
color: 'base.500',
},
...sx,
}}
/>
);
};

View File

@ -1,10 +1,10 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
import { setActiveTabReducer } from './extraReducers';
import { InvokeTabName } from './tabMap';
import { AddNewModelType, UIState } from './uiTypes';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
export const initialUIState: UIState = {
activeTab: 0,
@ -19,6 +19,7 @@ export const initialUIState: UIState = {
shouldShowGallery: true,
shouldHidePreview: false,
shouldShowProgressInViewer: true,
shouldShowEmbeddingPicker: false,
favoriteSchedulers: [],
};
@ -96,6 +97,9 @@ export const uiSlice = createSlice({
) => {
state.favoriteSchedulers = action.payload;
},
toggleEmbeddingPicker: (state) => {
state.shouldShowEmbeddingPicker = !state.shouldShowEmbeddingPicker;
},
},
extraReducers(builder) {
builder.addCase(initialImageChanged, (state) => {
@ -122,6 +126,7 @@ export const {
toggleGalleryPanel,
setShouldShowProgressInViewer,
favoriteSchedulersChanged,
toggleEmbeddingPicker,
} = uiSlice.actions;
export default uiSlice.reducer;

View File

@ -27,5 +27,6 @@ export interface UIState {
shouldPinGallery: boolean;
shouldShowGallery: boolean;
shouldShowProgressInViewer: boolean;
shouldShowEmbeddingPicker: boolean;
favoriteSchedulers: SchedulerParam[];
}

View File

@ -1,18 +1,18 @@
import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
import { io, Socket } from 'socket.io-client';
import { Socket, io } from 'socket.io-client';
import { AppThunkDispatch, RootState } from 'app/store/store';
import { getTimestamp } from 'common/util/getTimestamp';
import { sessionCreated } from 'services/api/thunks/session';
import {
ClientToServerEvents,
ServerToClientEvents,
} from 'services/events/types';
import { socketSubscribed, socketUnsubscribed } from './actions';
import { AppThunkDispatch, RootState } from 'app/store/store';
import { getTimestamp } from 'common/util/getTimestamp';
import { sessionCreated } from 'services/api/thunks/session';
// import { OpenAPI } from 'services/api/types';
import { setEventListeners } from 'services/events/util/setEventListeners';
import { log } from 'app/logging/useLogger';
import { $authToken, $baseUrl } from 'services/api/client';
import { setEventListeners } from 'services/events/util/setEventListeners';
const socketioLog = log.child({ namespace: 'socketio' });
@ -88,7 +88,7 @@ export const socketMiddleware = () => {
socketSubscribed({
sessionId: sessionId,
timestamp: getTimestamp(),
boardId: getState().boards.selectedBoardId,
boardId: getState().gallery.selectedBoardId,
})
);
}