Merge branch 'main' into lstein/model-manager-router-api

This commit is contained in:
Lincoln Stein 2023-07-06 13:20:36 -04:00 committed by GitHub
commit 581be42c75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 474 additions and 114 deletions

View File

@ -430,13 +430,13 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
max_height=len(PRECISION_CHOICES) + 1, max_height=len(PRECISION_CHOICES) + 1,
scroll_exit=True, scroll_exit=True,
) )
self.max_loaded_models = self.add_widget_intelligent( self.max_cache_size = self.add_widget_intelligent(
IntTitleSlider, IntTitleSlider,
name="Number of models to cache in CPU memory (each will use 2-4 GB!)", name="Size of the RAM cache used for fast model switching (GB)",
value=old_opts.max_loaded_models, value=old_opts.max_cache_size,
out_of=10, out_of=20,
lowest=1, lowest=3,
begin_entry_at=4, begin_entry_at=6,
scroll_exit=True, scroll_exit=True,
) )
self.nextrely += 1 self.nextrely += 1
@ -539,7 +539,7 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
"outdir", "outdir",
"nsfw_checker", "nsfw_checker",
"free_gpu_mem", "free_gpu_mem",
"max_loaded_models", "max_cache_size",
"xformers_enabled", "xformers_enabled",
"always_use_cpu", "always_use_cpu",
]: ]:
@ -555,9 +555,6 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
new_opts.license_acceptance = self.license_acceptance.value new_opts.license_acceptance = self.license_acceptance.value
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]] new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
# widget library workaround to make max_loaded_models an int rather than a float
new_opts.max_loaded_models = int(new_opts.max_loaded_models)
return new_opts return new_opts

View File

@ -196,8 +196,11 @@ class ModelInstall(object):
models_installed.update({str(path):self._install_path(path)}) models_installed.update({str(path):self._install_path(path)})
# folders style or similar # folders style or similar
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]): elif path.is_dir() and any([(path/x).exists() for x in \
models_installed.update({str(path): self._install_path(path)}) {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
]
):
models_installed.update(self._install_path(path))
# recursive scan # recursive scan
elif path.is_dir(): elif path.is_dir():

View File

@ -4,6 +4,7 @@ import copy
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional, Dict, Tuple, Any, Union, List from typing import Optional, Dict, Tuple, Any, Union, List
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union, List
import torch import torch
from compel.embeddings_provider import BaseTextualInversionManager from compel.embeddings_provider import BaseTextualInversionManager

View File

@ -8,7 +8,7 @@ The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the model into the GPU within the context, and unload outside the
context. Use like this: context. Use like this:
cache = ModelCache(max_models_cached=6) cache = ModelCache(max_cache_size=7.5)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1, with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2: cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2) do_something_in_GPU(SD1,SD2)
@ -91,7 +91,7 @@ class ModelCache(object):
logger: types.ModuleType = logger logger: types.ModuleType = logger
): ):
''' '''
:param max_models: Maximum number of models to cache in CPU RAM [4] :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param execution_device: Torch device to load active model into [torch.device('cuda')] :param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')] :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16] :param precision: Precision for loaded models [torch.float16]
@ -126,16 +126,6 @@ class ModelCache(object):
key += f":{submodel_type}" key += f":{submodel_type}"
return key return key
#def get_model(
# self,
# repo_id_or_path: Union[str, Path],
# model_type: ModelType = ModelType.Diffusers,
# subfolder: Path = None,
# submodel: ModelType = None,
# revision: str = None,
# attach_model_part: Tuple[ModelType, str] = (None, None),
# gpu_load: bool = True,
#) -> ModelLocker: # ?? what does it return
def _get_model_info( def _get_model_info(
self, self,
model_path: str, model_path: str,

View File

@ -852,7 +852,7 @@ class ModelManager(object):
if path in known_paths or path.parent in scanned_dirs: if path in known_paths or path.parent in scanned_dirs:
scanned_dirs.add(path) scanned_dirs.add(path)
continue continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]): if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
new_models_found.update(installer.heuristic_import(path)) new_models_found.update(installer.heuristic_import(path))
scanned_dirs.add(path) scanned_dirs.add(path)
@ -861,7 +861,8 @@ class ModelManager(object):
if path in known_paths or path.parent in scanned_dirs: if path in known_paths or path.parent in scanned_dirs:
continue continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
new_models_found.update(installer.heuristic_import(path)) import_result = installer.heuristic_import(path)
new_models_found.update(import_result)
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models') self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
installed.update(new_models_found) installed.update(new_models_found)

View File

@ -78,7 +78,6 @@ class ModelProbe(object):
format_type = 'diffusers' if model_path.is_dir() else 'checkpoint' format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
else: else:
format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint' format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
model_info = None model_info = None
try: try:
model_type = cls.get_model_type_from_folder(model_path, model) \ model_type = cls.get_model_type_from_folder(model_path, model) \
@ -105,7 +104,7 @@ class ModelProbe(object):
) else 512, ) else 512,
) )
except Exception: except Exception:
return None raise
return model_info return model_info
@ -127,6 +126,8 @@ class ModelProbe(object):
return ModelType.Vae return ModelType.Vae
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.Lora return ModelType.Lora
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
return ModelType.Lora
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}): elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}: elif key in {"emb_params", "string_to_param"}:
@ -137,7 +138,7 @@ class ModelProbe(object):
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion return ModelType.TextualInversion
raise ValueError("Unable to determine model type") raise ValueError(f"Unable to determine model type for {model_path}")
@classmethod @classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType: def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
@ -167,7 +168,7 @@ class ModelProbe(object):
return type return type
# give up # give up
raise ValueError("Unable to determine model type") raise ValueError("Unable to determine model type for {folder_path}")
@classmethod @classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict: def _scan_and_load_checkpoint(cls,model_path: Path)->dict:

View File

@ -678,9 +678,8 @@ def select_and_download_models(opt: Namespace):
# this is where the TUI is called # this is where the TUI is called
else: else:
# needed because the torch library is loaded, even though we don't use it # needed to support the probe() method running under a subprocess
# currently commented out because it has started generating errors (?) torch.multiprocessing.set_start_method("spawn")
# torch.multiprocessing.set_start_method("spawn")
# the third argument is needed in the Windows 11 environment in # the third argument is needed in the Windows 11 environment in
# order to launch and resize a console window running this program # order to launch and resize a console window running this program

View File

@ -1,6 +1,5 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions'; import { appSocketConnected, socketConnected } from 'services/events/actions';
import { receivedPageOfImages } from 'services/api/thunks/image';
import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { startAppListening } from '../..'; import { startAppListening } from '../..';
@ -14,19 +13,10 @@ export const addSocketConnectedEventListener = () => {
moduleLog.debug({ timestamp }, 'Connected'); moduleLog.debug({ timestamp }, 'Connected');
const { nodes, config, gallery } = getState(); const { nodes, config } = getState();
const { disabledTabs } = config; const { disabledTabs } = config;
if (!gallery.ids.length) {
dispatch(
receivedPageOfImages({
categories: ['general'],
is_intermediate: false,
})
);
}
if (!nodes.schema && !disabledTabs.includes('nodes')) { if (!nodes.schema && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema()); dispatch(receivedOpenAPISchema());
} }

View File

@ -1,15 +1,16 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react'; import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core'; import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { memo } from 'react'; import { RefObject, memo } from 'react';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
type IAIMultiSelectProps = MultiSelectProps & { type IAIMultiSelectProps = MultiSelectProps & {
tooltip?: string; tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
}; };
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, ...rest } = props; const { searchable = true, tooltip, inputRef, ...rest } = props;
const { const {
base50, base50,
base100, base100,
@ -33,6 +34,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow> <Tooltip label={tooltip} placement="top" hasArrow>
<MultiSelect <MultiSelect
ref={inputRef}
searchable={searchable} searchable={searchable}
styles={() => ({ styles={() => ({
label: { label: {

View File

@ -6,10 +6,15 @@ import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { systemSelector } from 'features/system/store/systemSelectors'; import { systemSelector } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
modelsApi,
useGetMainModelsQuery,
} from '../../services/api/endpoints/models';
const readinessSelector = createSelector( const readinessSelector = createSelector(
[stateSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
({ generation, system, batch }, activeTabName) => { (state, activeTabName) => {
const { generation, system, batch } = state;
const { shouldGenerateVariations, seedWeights, initialImage, seed } = const { shouldGenerateVariations, seedWeights, initialImage, seed } =
generation; generation;
@ -32,6 +37,13 @@ const readinessSelector = createSelector(
reasonsWhyNotReady.push('No initial image selected'); reasonsWhyNotReady.push('No initial image selected');
} }
const { isSuccess: mainModelsSuccessfullyLoaded } =
modelsApi.endpoints.getMainModels.select()(state);
if (!mainModelsSuccessfullyLoaded) {
isReady = false;
reasonsWhyNotReady.push('Models are not loaded');
}
// TODO: job queue // TODO: job queue
// Cannot generate if already processing an image // Cannot generate if already processing an image
if (isProcessing) { if (isProcessing) {

View File

@ -0,0 +1,33 @@
import IAIIconButton from 'common/components/IAIIconButton';
import { memo } from 'react';
import { BiCode } from 'react-icons/bi';
type Props = {
onClick: () => void;
};
const AddEmbeddingButton = (props: Props) => {
const { onClick } = props;
return (
<IAIIconButton
size="sm"
aria-label="Add Embedding"
tooltip="Add Embedding"
icon={<BiCode />}
sx={{
p: 2,
color: 'base.700',
_hover: {
color: 'base.550',
},
_active: {
color: 'base.500',
},
}}
variant="link"
onClick={onClick}
/>
);
};
export default memo(AddEmbeddingButton);

View File

@ -0,0 +1,151 @@
import {
Flex,
Popover,
PopoverBody,
PopoverContent,
PopoverTrigger,
Text,
} from '@chakra-ui/react';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { forEach } from 'lodash-es';
import {
PropsWithChildren,
forwardRef,
useCallback,
useMemo,
useRef,
} from 'react';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
type EmbeddingSelectItem = {
label: string;
value: string;
description?: string;
};
type Props = PropsWithChildren & {
onSelect: (v: string) => void;
isOpen: boolean;
onClose: () => void;
};
const ParamEmbeddingPopover = (props: Props) => {
const { onSelect, isOpen, onClose, children } = props;
const { data: embeddingQueryData } = useGetTextualInversionModelsQuery();
const inputRef = useRef<HTMLInputElement>(null);
const data = useMemo(() => {
if (!embeddingQueryData) {
return [];
}
const data: EmbeddingSelectItem[] = [];
forEach(embeddingQueryData.entities, (embedding, _) => {
if (!embedding) return;
data.push({
value: embedding.name,
label: embedding.name,
description: embedding.description,
});
});
return data;
}, [embeddingQueryData]);
const handleChange = useCallback(
(v: string[]) => {
if (v.length === 0) {
return;
}
onSelect(v[0]);
},
[onSelect]
);
return (
<Popover
initialFocusRef={inputRef}
isOpen={isOpen}
onClose={onClose}
placement="bottom"
openDelay={0}
closeDelay={0}
closeOnBlur={true}
returnFocusOnClose={true}
>
<PopoverTrigger>{children}</PopoverTrigger>
<PopoverContent
sx={{
p: 0,
top: -1,
shadow: 'dark-lg',
borderColor: 'accent.300',
borderWidth: '2px',
borderStyle: 'solid',
_dark: { borderColor: 'accent.400' },
}}
>
<PopoverBody
sx={{ p: 0, w: `calc(${PARAMETERS_PANEL_WIDTH} - 2rem )` }}
>
{data.length === 0 ? (
<Flex sx={{ justifyContent: 'center', p: 2 }}>
<Text
sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}
>
No Embeddings Loaded
</Text>
</Flex>
) : (
<IAIMantineMultiSelect
inputRef={inputRef}
placeholder={'Add Embedding'}
value={[]}
data={data}
maxDropdownHeight={400}
nothingFound="No Matching Embeddings"
itemComponent={SelectItem}
disabled={data.length === 0}
filter={(value, selected, item: EmbeddingSelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={handleChange}
/>
)}
</PopoverBody>
</PopoverContent>
</Popover>
);
};
export default ParamEmbeddingPopover;
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description?: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';

View File

@ -23,6 +23,7 @@ export const makeSelector = (image_name: string) =>
({ gallery }) => { ({ gallery }) => {
const isSelected = gallery.selection.includes(image_name); const isSelected = gallery.selection.includes(image_name);
const selectionCount = gallery.selection.length; const selectionCount = gallery.selection.length;
return { return {
isSelected, isSelected,
selectionCount, selectionCount,
@ -117,7 +118,7 @@ const GalleryImage = (props: HoverableImageProps) => {
resetIcon={<FaTrash />} resetIcon={<FaTrash />}
resetTooltip="Delete image" resetTooltip="Delete image"
imageSx={{ w: 'full', h: 'full' }} imageSx={{ w: 'full', h: 'full' }}
withResetIcon // withResetIcon // removed bc it's too easy to accidentally delete images
isDropDisabled={true} isDropDisabled={true}
isUploadDisabled={true} isUploadDisabled={true}
/> />

View File

@ -182,6 +182,15 @@ const ImageGalleryContent = () => {
return () => osInstance()?.destroy(); return () => osInstance()?.destroy();
}, [scroller, initialize, osInstance]); }, [scroller, initialize, osInstance]);
useEffect(() => {
dispatch(
receivedPageOfImages({
categories: ['general'],
is_intermediate: false,
})
);
}, [dispatch]);
const handleClickImagesCategory = useCallback(() => { const handleClickImagesCategory = useCallback(() => {
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES)); dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
dispatch(setGalleryView('images')); dispatch(setGalleryView('images'));

View File

@ -4,7 +4,12 @@ import IAIIconButton from 'common/components/IAIIconButton';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { FaTrash } from 'react-icons/fa'; import { FaTrash } from 'react-icons/fa';
import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice'; import {
Lora,
loraRemoved,
loraWeightChanged,
loraWeightReset,
} from '../store/loraSlice';
type Props = { type Props = {
lora: Lora; lora: Lora;
@ -22,7 +27,7 @@ const ParamLora = (props: Props) => {
); );
const handleReset = useCallback(() => { const handleReset = useCallback(() => {
dispatch(loraWeightChanged({ id: lora.id, weight: 1 })); dispatch(loraWeightReset(lora.id));
}, [dispatch, lora.id]); }, [dispatch, lora.id]);
const handleRemoveLora = useCallback(() => { const handleRemoveLora = useCallback(() => {

View File

@ -1,4 +1,4 @@
import { Text } from '@chakra-ui/react'; import { Flex, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@ -61,6 +61,16 @@ const ParamLoraSelect = () => {
[dispatch, lorasQueryData?.entities] [dispatch, lorasQueryData?.entities]
); );
if (lorasQueryData?.ids.length === 0) {
return (
<Flex sx={{ justifyContent: 'center', p: 2 }}>
<Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
No LoRAs Loaded
</Text>
</Flex>
);
}
return ( return (
<IAIMantineMultiSelect <IAIMantineMultiSelect
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'} placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}

View File

@ -8,7 +8,7 @@ export type Lora = {
}; };
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = { export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
weight: 1, weight: 0.75,
}; };
export type LoraState = { export type LoraState = {
@ -38,9 +38,14 @@ export const loraSlice = createSlice({
const { id, weight } = action.payload; const { id, weight } = action.payload;
state.loras[id].weight = weight; state.loras[id].weight = weight;
}, },
loraWeightReset: (state, action: PayloadAction<string>) => {
const id = action.payload;
state.loras[id].weight = defaultLoRAConfig.weight;
},
}, },
}); });
export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions; export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } =
loraSlice.actions;
export default loraSlice.reducer; export default loraSlice.reducer;

View File

@ -1,29 +1,107 @@
import { FormControl } from '@chakra-ui/react'; import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea'; import IAITextarea from 'common/components/IAITextarea';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
import { setNegativePrompt } from 'features/parameters/store/generationSlice'; import { setNegativePrompt } from 'features/parameters/store/generationSlice';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
import { flushSync } from 'react-dom';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const ParamNegativeConditioning = () => { const ParamNegativeConditioning = () => {
const negativePrompt = useAppSelector( const negativePrompt = useAppSelector(
(state: RootState) => state.generation.negativePrompt (state: RootState) => state.generation.negativePrompt
); );
const promptRef = useRef<HTMLTextAreaElement>(null);
const { isOpen, onClose, onOpen } = useDisclosure();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const handleChangePrompt = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setNegativePrompt(e.target.value));
},
[dispatch]
);
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === '<') {
onOpen();
}
},
[onOpen]
);
const handleSelectEmbedding = useCallback(
(v: string) => {
if (!promptRef.current) {
return;
}
// this is where we insert the TI trigger
const caret = promptRef.current.selectionStart;
if (caret === undefined) {
return;
}
let newPrompt = negativePrompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;
newPrompt += negativePrompt.slice(caret);
// must flush dom updates else selection gets reset
flushSync(() => {
dispatch(setNegativePrompt(newPrompt));
});
// set the caret position to just after the TI trigger promptRef.current.selectionStart = finalCaretPos;
promptRef.current.selectionEnd = finalCaretPos;
onClose();
},
[dispatch, onClose, negativePrompt]
);
return ( return (
<FormControl> <FormControl>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea <IAITextarea
id="negativePrompt" id="negativePrompt"
name="negativePrompt" name="negativePrompt"
ref={promptRef}
value={negativePrompt} value={negativePrompt}
onChange={(e) => dispatch(setNegativePrompt(e.target.value))}
placeholder={t('parameters.negativePromptPlaceholder')} placeholder={t('parameters.negativePromptPlaceholder')}
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
fontSize="sm" fontSize="sm"
minH={16} minH={16}
/> />
</ParamEmbeddingPopover>
{!isOpen && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</FormControl> </FormControl>
); );
}; };

View File

@ -1,4 +1,4 @@
import { Box, FormControl } from '@chakra-ui/react'; import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react'; import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
@ -11,12 +11,15 @@ import {
} from 'features/parameters/store/generationSlice'; } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import IAITextarea from 'common/components/IAITextarea'; import IAITextarea from 'common/components/IAITextarea';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
import { isEqual } from 'lodash-es';
import { flushSync } from 'react-dom';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
const promptInputSelector = createSelector( const promptInputSelector = createSelector(
[(state: RootState) => state.generation, activeTabNameSelector], [(state: RootState) => state.generation, activeTabNameSelector],
@ -40,14 +43,15 @@ const ParamPositiveConditioning = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { prompt, activeTabName } = useAppSelector(promptInputSelector); const { prompt, activeTabName } = useAppSelector(promptInputSelector);
const isReady = useIsReadyToInvoke(); const isReady = useIsReadyToInvoke();
const promptRef = useRef<HTMLTextAreaElement>(null); const promptRef = useRef<HTMLTextAreaElement>(null);
const { isOpen, onClose, onOpen } = useDisclosure();
const { t } = useTranslation(); const { t } = useTranslation();
const handleChangePrompt = useCallback(
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) => { (e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setPositivePrompt(e.target.value)); dispatch(setPositivePrompt(e.target.value));
}; },
[dispatch]
);
useHotkeys( useHotkeys(
'alt+a', 'alt+a',
@ -57,6 +61,45 @@ const ParamPositiveConditioning = () => {
[] []
); );
const handleSelectEmbedding = useCallback(
(v: string) => {
if (!promptRef.current) {
return;
}
// this is where we insert the TI trigger
const caret = promptRef.current.selectionStart;
if (caret === undefined) {
return;
}
let newPrompt = prompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;
newPrompt += prompt.slice(caret);
// must flush dom updates else selection gets reset
flushSync(() => {
dispatch(setPositivePrompt(newPrompt));
});
// set the caret position to just after the TI trigger
promptRef.current.selectionStart = finalCaretPos;
promptRef.current.selectionEnd = finalCaretPos;
onClose();
},
[dispatch, onClose, prompt]
);
const handleKeyDown = useCallback( const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => { (e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false && isReady) { if (e.key === 'Enter' && e.shiftKey === false && isReady) {
@ -64,25 +107,50 @@ const ParamPositiveConditioning = () => {
dispatch(clampSymmetrySteps()); dispatch(clampSymmetrySteps());
dispatch(userInvoked(activeTabName)); dispatch(userInvoked(activeTabName));
} }
if (e.key === '<') {
onOpen();
}
}, },
[dispatch, activeTabName, isReady] [isReady, dispatch, activeTabName, onOpen]
); );
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
// const target = e.target as HTMLTextAreaElement;
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
// };
return ( return (
<Box> <Box>
<FormControl> <FormControl>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea <IAITextarea
id="prompt" id="prompt"
name="prompt" name="prompt"
placeholder={t('parameters.positivePromptPlaceholder')} ref={promptRef}
value={prompt} value={prompt}
placeholder={t('parameters.positivePromptPlaceholder')}
onChange={handleChangePrompt} onChange={handleChangePrompt}
onKeyDown={handleKeyDown} onKeyDown={handleKeyDown}
resize="vertical" resize="vertical"
ref={promptRef}
minH={32} minH={32}
/> />
</ParamEmbeddingPopover>
</FormControl> </FormControl>
{!isOpen && (
<Box
sx={{
position: 'absolute',
top: 6,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</Box> </Box>
); );
}; };

View File

@ -1,4 +1,3 @@
import { Tooltip } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton, { import IAIIconButton, {
IAIIconButtonProps, IAIIconButtonProps,
@ -25,9 +24,9 @@ const PinParametersPanelButton = (props: PinParametersPanelButtonProps) => {
}; };
return ( return (
<Tooltip label={t('common.pinOptionsPanel')}>
<IAIIconButton <IAIIconButton
{...props} {...props}
tooltip={t('common.pinOptionsPanel')}
aria-label={t('common.pinOptionsPanel')} aria-label={t('common.pinOptionsPanel')}
onClick={handleClickPinOptionsPanel} onClick={handleClickPinOptionsPanel}
icon={shouldPinParametersPanel ? <BsPinAngleFill /> : <BsPinAngle />} icon={shouldPinParametersPanel ? <BsPinAngleFill /> : <BsPinAngle />}
@ -44,7 +43,6 @@ const PinParametersPanelButton = (props: PinParametersPanelButtonProps) => {
...sx, ...sx,
}} }}
/> />
</Tooltip>
); );
}; };

View File

@ -1,10 +1,10 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
import { setActiveTabReducer } from './extraReducers'; import { setActiveTabReducer } from './extraReducers';
import { InvokeTabName } from './tabMap'; import { InvokeTabName } from './tabMap';
import { AddNewModelType, UIState } from './uiTypes'; import { AddNewModelType, UIState } from './uiTypes';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
export const initialUIState: UIState = { export const initialUIState: UIState = {
activeTab: 0, activeTab: 0,
@ -19,6 +19,7 @@ export const initialUIState: UIState = {
shouldShowGallery: true, shouldShowGallery: true,
shouldHidePreview: false, shouldHidePreview: false,
shouldShowProgressInViewer: true, shouldShowProgressInViewer: true,
shouldShowEmbeddingPicker: false,
favoriteSchedulers: [], favoriteSchedulers: [],
}; };
@ -96,6 +97,9 @@ export const uiSlice = createSlice({
) => { ) => {
state.favoriteSchedulers = action.payload; state.favoriteSchedulers = action.payload;
}, },
toggleEmbeddingPicker: (state) => {
state.shouldShowEmbeddingPicker = !state.shouldShowEmbeddingPicker;
},
}, },
extraReducers(builder) { extraReducers(builder) {
builder.addCase(initialImageChanged, (state) => { builder.addCase(initialImageChanged, (state) => {
@ -122,6 +126,7 @@ export const {
toggleGalleryPanel, toggleGalleryPanel,
setShouldShowProgressInViewer, setShouldShowProgressInViewer,
favoriteSchedulersChanged, favoriteSchedulersChanged,
toggleEmbeddingPicker,
} = uiSlice.actions; } = uiSlice.actions;
export default uiSlice.reducer; export default uiSlice.reducer;

View File

@ -27,5 +27,6 @@ export interface UIState {
shouldPinGallery: boolean; shouldPinGallery: boolean;
shouldShowGallery: boolean; shouldShowGallery: boolean;
shouldShowProgressInViewer: boolean; shouldShowProgressInViewer: boolean;
shouldShowEmbeddingPicker: boolean;
favoriteSchedulers: SchedulerParam[]; favoriteSchedulers: SchedulerParam[];
} }

View File

@ -1,18 +1,18 @@
import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit'; import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
import { io, Socket } from 'socket.io-client'; import { Socket, io } from 'socket.io-client';
import { AppThunkDispatch, RootState } from 'app/store/store';
import { getTimestamp } from 'common/util/getTimestamp';
import { sessionCreated } from 'services/api/thunks/session';
import { import {
ClientToServerEvents, ClientToServerEvents,
ServerToClientEvents, ServerToClientEvents,
} from 'services/events/types'; } from 'services/events/types';
import { socketSubscribed, socketUnsubscribed } from './actions'; import { socketSubscribed, socketUnsubscribed } from './actions';
import { AppThunkDispatch, RootState } from 'app/store/store';
import { getTimestamp } from 'common/util/getTimestamp';
import { sessionCreated } from 'services/api/thunks/session';
// import { OpenAPI } from 'services/api/types'; // import { OpenAPI } from 'services/api/types';
import { setEventListeners } from 'services/events/util/setEventListeners';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { $authToken, $baseUrl } from 'services/api/client'; import { $authToken, $baseUrl } from 'services/api/client';
import { setEventListeners } from 'services/events/util/setEventListeners';
const socketioLog = log.child({ namespace: 'socketio' }); const socketioLog = log.child({ namespace: 'socketio' });
@ -88,7 +88,7 @@ export const socketMiddleware = () => {
socketSubscribed({ socketSubscribed({
sessionId: sessionId, sessionId: sessionId,
timestamp: getTimestamp(), timestamp: getTimestamp(),
boardId: getState().boards.selectedBoardId, boardId: getState().gallery.selectedBoardId,
}) })
); );
} }