mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/configure-max-cache-size
This commit is contained in:
commit
b229fe19aa
@ -193,7 +193,10 @@ class ModelInstall(object):
|
||||
models_installed.update(self._install_path(path))
|
||||
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||
elif path.is_dir() and any([(path/x).exists() for x in \
|
||||
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||
]
|
||||
):
|
||||
models_installed.update(self._install_path(path))
|
||||
|
||||
# recursive scan
|
||||
|
@ -3,15 +3,13 @@ from __future__ import annotations
|
||||
import copy
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple, Union, List
|
||||
|
||||
import torch
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from safetensors.torch import load_file
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
class LoRALayerBase:
|
||||
#rank: Optional[int]
|
||||
@ -123,8 +121,8 @@ class LoRALayer(LoRALayerBase):
|
||||
|
||||
def get_weight(self):
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(up.shape[0], up.shape[1])
|
||||
down = self.down.reshape(up.shape[0], up.shape[1])
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
@ -410,7 +408,7 @@ class LoRAModel: #(torch.nn.Module):
|
||||
else:
|
||||
# TODO: diff/ia3/... format
|
||||
print(
|
||||
f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
|
||||
f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -785,7 +785,7 @@ class ModelManager(object):
|
||||
if path in known_paths or path.parent in scanned_dirs:
|
||||
scanned_dirs.add(path)
|
||||
continue
|
||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
||||
new_models_found.update(installer.heuristic_import(path))
|
||||
scanned_dirs.add(path)
|
||||
|
||||
@ -794,7 +794,8 @@ class ModelManager(object):
|
||||
if path in known_paths or path.parent in scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||
new_models_found.update(installer.heuristic_import(path))
|
||||
import_result = installer.heuristic_import(path)
|
||||
new_models_found.update(import_result)
|
||||
|
||||
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
||||
installed.update(new_models_found)
|
||||
|
@ -78,7 +78,6 @@ class ModelProbe(object):
|
||||
format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
|
||||
else:
|
||||
format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
|
||||
|
||||
model_info = None
|
||||
try:
|
||||
model_type = cls.get_model_type_from_folder(model_path, model) \
|
||||
@ -105,7 +104,7 @@ class ModelProbe(object):
|
||||
) else 512,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
raise
|
||||
|
||||
return model_info
|
||||
|
||||
@ -127,6 +126,8 @@ class ModelProbe(object):
|
||||
return ModelType.Vae
|
||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||
return ModelType.ControlNet
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
@ -137,7 +138,7 @@ class ModelProbe(object):
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
return ModelType.TextualInversion
|
||||
|
||||
raise ValueError("Unable to determine model type")
|
||||
raise ValueError(f"Unable to determine model type for {model_path}")
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
||||
@ -167,7 +168,7 @@ class ModelProbe(object):
|
||||
return type
|
||||
|
||||
# give up
|
||||
raise ValueError("Unable to determine model type")
|
||||
raise ValueError("Unable to determine model type for {folder_path}")
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
||||
|
@ -678,9 +678,8 @@ def select_and_download_models(opt: Namespace):
|
||||
|
||||
# this is where the TUI is called
|
||||
else:
|
||||
# needed because the torch library is loaded, even though we don't use it
|
||||
# currently commented out because it has started generating errors (?)
|
||||
# torch.multiprocessing.set_start_method("spawn")
|
||||
# needed to support the probe() method running under a subprocess
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
# the third argument is needed in the Windows 11 environment in
|
||||
# order to launch and resize a console window running this program
|
||||
|
@ -1,15 +1,16 @@
|
||||
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
|
||||
import { MultiSelect, MultiSelectProps } from '@mantine/core';
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { memo } from 'react';
|
||||
import { RefObject, memo } from 'react';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
type IAIMultiSelectProps = MultiSelectProps & {
|
||||
tooltip?: string;
|
||||
inputRef?: RefObject<HTMLInputElement>;
|
||||
};
|
||||
|
||||
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
const { searchable = true, tooltip, ...rest } = props;
|
||||
const { searchable = true, tooltip, inputRef, ...rest } = props;
|
||||
const {
|
||||
base50,
|
||||
base100,
|
||||
@ -33,6 +34,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||
<MultiSelect
|
||||
ref={inputRef}
|
||||
searchable={searchable}
|
||||
styles={() => ({
|
||||
label: {
|
||||
|
@ -0,0 +1,33 @@
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { memo } from 'react';
|
||||
import { BiCode } from 'react-icons/bi';
|
||||
|
||||
type Props = {
|
||||
onClick: () => void;
|
||||
};
|
||||
|
||||
const AddEmbeddingButton = (props: Props) => {
|
||||
const { onClick } = props;
|
||||
return (
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
aria-label="Add Embedding"
|
||||
tooltip="Add Embedding"
|
||||
icon={<BiCode />}
|
||||
sx={{
|
||||
p: 2,
|
||||
color: 'base.700',
|
||||
_hover: {
|
||||
color: 'base.550',
|
||||
},
|
||||
_active: {
|
||||
color: 'base.500',
|
||||
},
|
||||
}}
|
||||
variant="link"
|
||||
onClick={onClick}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(AddEmbeddingButton);
|
@ -0,0 +1,151 @@
|
||||
import {
|
||||
Flex,
|
||||
Popover,
|
||||
PopoverBody,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||
import { forEach } from 'lodash-es';
|
||||
import {
|
||||
PropsWithChildren,
|
||||
forwardRef,
|
||||
useCallback,
|
||||
useMemo,
|
||||
useRef,
|
||||
} from 'react';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
|
||||
|
||||
type EmbeddingSelectItem = {
|
||||
label: string;
|
||||
value: string;
|
||||
description?: string;
|
||||
};
|
||||
|
||||
type Props = PropsWithChildren & {
|
||||
onSelect: (v: string) => void;
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
};
|
||||
|
||||
const ParamEmbeddingPopover = (props: Props) => {
|
||||
const { onSelect, isOpen, onClose, children } = props;
|
||||
const { data: embeddingQueryData } = useGetTextualInversionModelsQuery();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!embeddingQueryData) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: EmbeddingSelectItem[] = [];
|
||||
|
||||
forEach(embeddingQueryData.entities, (embedding, _) => {
|
||||
if (!embedding) return;
|
||||
|
||||
data.push({
|
||||
value: embedding.name,
|
||||
label: embedding.name,
|
||||
description: embedding.description,
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [embeddingQueryData]);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string[]) => {
|
||||
if (v.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
onSelect(v[0]);
|
||||
},
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
return (
|
||||
<Popover
|
||||
initialFocusRef={inputRef}
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
placement="bottom"
|
||||
openDelay={0}
|
||||
closeDelay={0}
|
||||
closeOnBlur={true}
|
||||
returnFocusOnClose={true}
|
||||
>
|
||||
<PopoverTrigger>{children}</PopoverTrigger>
|
||||
<PopoverContent
|
||||
sx={{
|
||||
p: 0,
|
||||
top: -1,
|
||||
shadow: 'dark-lg',
|
||||
borderColor: 'accent.300',
|
||||
borderWidth: '2px',
|
||||
borderStyle: 'solid',
|
||||
_dark: { borderColor: 'accent.400' },
|
||||
}}
|
||||
>
|
||||
<PopoverBody
|
||||
sx={{ p: 0, w: `calc(${PARAMETERS_PANEL_WIDTH} - 2rem )` }}
|
||||
>
|
||||
{data.length === 0 ? (
|
||||
<Flex sx={{ justifyContent: 'center', p: 2 }}>
|
||||
<Text
|
||||
sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}
|
||||
>
|
||||
No Embeddings Loaded
|
||||
</Text>
|
||||
</Flex>
|
||||
) : (
|
||||
<IAIMantineMultiSelect
|
||||
inputRef={inputRef}
|
||||
placeholder={'Add Embedding'}
|
||||
value={[]}
|
||||
data={data}
|
||||
maxDropdownHeight={400}
|
||||
nothingFound="No Matching Embeddings"
|
||||
itemComponent={SelectItem}
|
||||
disabled={data.length === 0}
|
||||
filter={(value, selected, item: EmbeddingSelectItem) =>
|
||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||
}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
)}
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamEmbeddingPopover;
|
||||
|
||||
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||
value: string;
|
||||
label: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
|
||||
({ label, description, ...others }: ItemProps, ref) => {
|
||||
return (
|
||||
<div ref={ref} {...others}>
|
||||
<div>
|
||||
<Text>{label}</Text>
|
||||
{description && (
|
||||
<Text size="xs" color="base.600">
|
||||
{description}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
SelectItem.displayName = 'SelectItem';
|
@ -23,6 +23,7 @@ export const makeSelector = (image_name: string) =>
|
||||
({ gallery }) => {
|
||||
const isSelected = gallery.selection.includes(image_name);
|
||||
const selectionCount = gallery.selection.length;
|
||||
|
||||
return {
|
||||
isSelected,
|
||||
selectionCount,
|
||||
@ -117,7 +118,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
resetIcon={<FaTrash />}
|
||||
resetTooltip="Delete image"
|
||||
imageSx={{ w: 'full', h: 'full' }}
|
||||
withResetIcon
|
||||
// withResetIcon // removed bc it's too easy to accidentally delete images
|
||||
isDropDisabled={true}
|
||||
isUploadDisabled={true}
|
||||
/>
|
||||
|
@ -4,7 +4,12 @@ import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice';
|
||||
import {
|
||||
Lora,
|
||||
loraRemoved,
|
||||
loraWeightChanged,
|
||||
loraWeightReset,
|
||||
} from '../store/loraSlice';
|
||||
|
||||
type Props = {
|
||||
lora: Lora;
|
||||
@ -22,7 +27,7 @@ const ParamLora = (props: Props) => {
|
||||
);
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
dispatch(loraWeightChanged({ id: lora.id, weight: 1 }));
|
||||
dispatch(loraWeightReset(lora.id));
|
||||
}, [dispatch, lora.id]);
|
||||
|
||||
const handleRemoveLora = useCallback(() => {
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Text } from '@chakra-ui/react';
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@ -61,6 +61,16 @@ const ParamLoraSelect = () => {
|
||||
[dispatch, lorasQueryData?.entities]
|
||||
);
|
||||
|
||||
if (lorasQueryData?.ids.length === 0) {
|
||||
return (
|
||||
<Flex sx={{ justifyContent: 'center', p: 2 }}>
|
||||
<Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
|
||||
No LoRAs Loaded
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<IAIMantineMultiSelect
|
||||
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}
|
||||
|
@ -8,7 +8,7 @@ export type Lora = {
|
||||
};
|
||||
|
||||
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
|
||||
weight: 1,
|
||||
weight: 0.75,
|
||||
};
|
||||
|
||||
export type LoraState = {
|
||||
@ -38,9 +38,14 @@ export const loraSlice = createSlice({
|
||||
const { id, weight } = action.payload;
|
||||
state.loras[id].weight = weight;
|
||||
},
|
||||
loraWeightReset: (state, action: PayloadAction<string>) => {
|
||||
const id = action.payload;
|
||||
state.loras[id].weight = defaultLoRAConfig.weight;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions;
|
||||
export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } =
|
||||
loraSlice.actions;
|
||||
|
||||
export default loraSlice.reducer;
|
||||
|
@ -1,29 +1,107 @@
|
||||
import { FormControl } from '@chakra-ui/react';
|
||||
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
|
||||
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
|
||||
import { setNegativePrompt } from 'features/parameters/store/generationSlice';
|
||||
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
|
||||
import { flushSync } from 'react-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const ParamNegativeConditioning = () => {
|
||||
const negativePrompt = useAppSelector(
|
||||
(state: RootState) => state.generation.negativePrompt
|
||||
);
|
||||
|
||||
const promptRef = useRef<HTMLTextAreaElement>(null);
|
||||
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChangePrompt = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(setNegativePrompt(e.target.value));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (e.key === '<') {
|
||||
onOpen();
|
||||
}
|
||||
},
|
||||
[onOpen]
|
||||
);
|
||||
|
||||
const handleSelectEmbedding = useCallback(
|
||||
(v: string) => {
|
||||
if (!promptRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
// this is where we insert the TI trigger
|
||||
const caret = promptRef.current.selectionStart;
|
||||
|
||||
if (caret === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
let newPrompt = negativePrompt.slice(0, caret);
|
||||
|
||||
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||
newPrompt += '<';
|
||||
}
|
||||
|
||||
newPrompt += `${v}>`;
|
||||
|
||||
// we insert the cursor after the `>`
|
||||
const finalCaretPos = newPrompt.length;
|
||||
|
||||
newPrompt += negativePrompt.slice(caret);
|
||||
|
||||
// must flush dom updates else selection gets reset
|
||||
flushSync(() => {
|
||||
dispatch(setNegativePrompt(newPrompt));
|
||||
});
|
||||
|
||||
// set the caret position to just after the TI trigger promptRef.current.selectionStart = finalCaretPos;
|
||||
promptRef.current.selectionEnd = finalCaretPos;
|
||||
onClose();
|
||||
},
|
||||
[dispatch, onClose, negativePrompt]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<ParamEmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={handleSelectEmbedding}
|
||||
>
|
||||
<IAITextarea
|
||||
id="negativePrompt"
|
||||
name="negativePrompt"
|
||||
ref={promptRef}
|
||||
value={negativePrompt}
|
||||
onChange={(e) => dispatch(setNegativePrompt(e.target.value))}
|
||||
placeholder={t('parameters.negativePromptPlaceholder')}
|
||||
onChange={handleChangePrompt}
|
||||
onKeyDown={handleKeyDown}
|
||||
resize="vertical"
|
||||
fontSize="sm"
|
||||
minH={16}
|
||||
/>
|
||||
</ParamEmbeddingPopover>
|
||||
{!isOpen && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineEnd: 0,
|
||||
}}
|
||||
>
|
||||
<AddEmbeddingButton onClick={onOpen} />
|
||||
</Box>
|
||||
)}
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Box, FormControl } from '@chakra-ui/react';
|
||||
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
|
||||
@ -11,12 +11,15 @@ import {
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
|
||||
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { flushSync } from 'react-dom';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const promptInputSelector = createSelector(
|
||||
[(state: RootState) => state.generation, activeTabNameSelector],
|
||||
@ -40,14 +43,15 @@ const ParamPositiveConditioning = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
|
||||
const isReady = useIsReadyToInvoke();
|
||||
|
||||
const promptRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
const handleChangePrompt = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(setPositivePrompt(e.target.value));
|
||||
};
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
'alt+a',
|
||||
@ -57,6 +61,45 @@ const ParamPositiveConditioning = () => {
|
||||
[]
|
||||
);
|
||||
|
||||
const handleSelectEmbedding = useCallback(
|
||||
(v: string) => {
|
||||
if (!promptRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
// this is where we insert the TI trigger
|
||||
const caret = promptRef.current.selectionStart;
|
||||
|
||||
if (caret === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
let newPrompt = prompt.slice(0, caret);
|
||||
|
||||
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||
newPrompt += '<';
|
||||
}
|
||||
|
||||
newPrompt += `${v}>`;
|
||||
|
||||
// we insert the cursor after the `>`
|
||||
const finalCaretPos = newPrompt.length;
|
||||
|
||||
newPrompt += prompt.slice(caret);
|
||||
|
||||
// must flush dom updates else selection gets reset
|
||||
flushSync(() => {
|
||||
dispatch(setPositivePrompt(newPrompt));
|
||||
});
|
||||
|
||||
// set the caret position to just after the TI trigger
|
||||
promptRef.current.selectionStart = finalCaretPos;
|
||||
promptRef.current.selectionEnd = finalCaretPos;
|
||||
onClose();
|
||||
},
|
||||
[dispatch, onClose, prompt]
|
||||
);
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
|
||||
@ -64,25 +107,50 @@ const ParamPositiveConditioning = () => {
|
||||
dispatch(clampSymmetrySteps());
|
||||
dispatch(userInvoked(activeTabName));
|
||||
}
|
||||
if (e.key === '<') {
|
||||
onOpen();
|
||||
}
|
||||
},
|
||||
[dispatch, activeTabName, isReady]
|
||||
[isReady, dispatch, activeTabName, onOpen]
|
||||
);
|
||||
|
||||
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
|
||||
// const target = e.target as HTMLTextAreaElement;
|
||||
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
|
||||
// };
|
||||
|
||||
return (
|
||||
<Box>
|
||||
<FormControl>
|
||||
<ParamEmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={handleSelectEmbedding}
|
||||
>
|
||||
<IAITextarea
|
||||
id="prompt"
|
||||
name="prompt"
|
||||
placeholder={t('parameters.positivePromptPlaceholder')}
|
||||
ref={promptRef}
|
||||
value={prompt}
|
||||
placeholder={t('parameters.positivePromptPlaceholder')}
|
||||
onChange={handleChangePrompt}
|
||||
onKeyDown={handleKeyDown}
|
||||
resize="vertical"
|
||||
ref={promptRef}
|
||||
minH={32}
|
||||
/>
|
||||
</ParamEmbeddingPopover>
|
||||
</FormControl>
|
||||
{!isOpen && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 6,
|
||||
insetInlineEnd: 0,
|
||||
}}
|
||||
>
|
||||
<AddEmbeddingButton onClick={onOpen} />
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
@ -1,4 +1,3 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton, {
|
||||
IAIIconButtonProps,
|
||||
@ -25,9 +24,9 @@ const PinParametersPanelButton = (props: PinParametersPanelButtonProps) => {
|
||||
};
|
||||
|
||||
return (
|
||||
<Tooltip label={t('common.pinOptionsPanel')}>
|
||||
<IAIIconButton
|
||||
{...props}
|
||||
tooltip={t('common.pinOptionsPanel')}
|
||||
aria-label={t('common.pinOptionsPanel')}
|
||||
onClick={handleClickPinOptionsPanel}
|
||||
icon={shouldPinParametersPanel ? <BsPinAngleFill /> : <BsPinAngle />}
|
||||
@ -44,7 +43,6 @@ const PinParametersPanelButton = (props: PinParametersPanelButtonProps) => {
|
||||
...sx,
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
import { setActiveTabReducer } from './extraReducers';
|
||||
import { InvokeTabName } from './tabMap';
|
||||
import { AddNewModelType, UIState } from './uiTypes';
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
|
||||
export const initialUIState: UIState = {
|
||||
activeTab: 0,
|
||||
@ -19,6 +19,7 @@ export const initialUIState: UIState = {
|
||||
shouldShowGallery: true,
|
||||
shouldHidePreview: false,
|
||||
shouldShowProgressInViewer: true,
|
||||
shouldShowEmbeddingPicker: false,
|
||||
favoriteSchedulers: [],
|
||||
};
|
||||
|
||||
@ -96,6 +97,9 @@ export const uiSlice = createSlice({
|
||||
) => {
|
||||
state.favoriteSchedulers = action.payload;
|
||||
},
|
||||
toggleEmbeddingPicker: (state) => {
|
||||
state.shouldShowEmbeddingPicker = !state.shouldShowEmbeddingPicker;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(initialImageChanged, (state) => {
|
||||
@ -122,6 +126,7 @@ export const {
|
||||
toggleGalleryPanel,
|
||||
setShouldShowProgressInViewer,
|
||||
favoriteSchedulersChanged,
|
||||
toggleEmbeddingPicker,
|
||||
} = uiSlice.actions;
|
||||
|
||||
export default uiSlice.reducer;
|
||||
|
@ -27,5 +27,6 @@ export interface UIState {
|
||||
shouldPinGallery: boolean;
|
||||
shouldShowGallery: boolean;
|
||||
shouldShowProgressInViewer: boolean;
|
||||
shouldShowEmbeddingPicker: boolean;
|
||||
favoriteSchedulers: SchedulerParam[];
|
||||
}
|
||||
|
@ -1,18 +1,18 @@
|
||||
import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||
import { io, Socket } from 'socket.io-client';
|
||||
import { Socket, io } from 'socket.io-client';
|
||||
|
||||
import { AppThunkDispatch, RootState } from 'app/store/store';
|
||||
import { getTimestamp } from 'common/util/getTimestamp';
|
||||
import { sessionCreated } from 'services/api/thunks/session';
|
||||
import {
|
||||
ClientToServerEvents,
|
||||
ServerToClientEvents,
|
||||
} from 'services/events/types';
|
||||
import { socketSubscribed, socketUnsubscribed } from './actions';
|
||||
import { AppThunkDispatch, RootState } from 'app/store/store';
|
||||
import { getTimestamp } from 'common/util/getTimestamp';
|
||||
import { sessionCreated } from 'services/api/thunks/session';
|
||||
// import { OpenAPI } from 'services/api/types';
|
||||
import { setEventListeners } from 'services/events/util/setEventListeners';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { $authToken, $baseUrl } from 'services/api/client';
|
||||
import { setEventListeners } from 'services/events/util/setEventListeners';
|
||||
|
||||
const socketioLog = log.child({ namespace: 'socketio' });
|
||||
|
||||
@ -88,7 +88,7 @@ export const socketMiddleware = () => {
|
||||
socketSubscribed({
|
||||
sessionId: sessionId,
|
||||
timestamp: getTimestamp(),
|
||||
boardId: getState().boards.selectedBoardId,
|
||||
boardId: getState().gallery.selectedBoardId,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user