mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/configure-max-cache-size
This commit is contained in:
commit
b229fe19aa
@ -193,7 +193,10 @@ class ModelInstall(object):
|
|||||||
models_installed.update(self._install_path(path))
|
models_installed.update(self._install_path(path))
|
||||||
|
|
||||||
# folders style or similar
|
# folders style or similar
|
||||||
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
elif path.is_dir() and any([(path/x).exists() for x in \
|
||||||
|
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||||
|
]
|
||||||
|
):
|
||||||
models_installed.update(self._install_path(path))
|
models_installed.update(self._install_path(path))
|
||||||
|
|
||||||
# recursive scan
|
# recursive scan
|
||||||
|
@ -3,15 +3,13 @@ from __future__ import annotations
|
|||||||
import copy
|
import copy
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple, Union, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compel.embeddings_provider import BaseTextualInversionManager
|
from compel.embeddings_provider import BaseTextualInversionManager
|
||||||
from diffusers.models import UNet2DConditionModel
|
from diffusers.models import UNet2DConditionModel
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from torch.utils.hooks import RemovableHandle
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
from transformers import CLIPTextModel
|
|
||||||
|
|
||||||
|
|
||||||
class LoRALayerBase:
|
class LoRALayerBase:
|
||||||
#rank: Optional[int]
|
#rank: Optional[int]
|
||||||
@ -123,8 +121,8 @@ class LoRALayer(LoRALayerBase):
|
|||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self):
|
||||||
if self.mid is not None:
|
if self.mid is not None:
|
||||||
up = self.up.reshape(up.shape[0], up.shape[1])
|
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||||
down = self.down.reshape(up.shape[0], up.shape[1])
|
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||||
else:
|
else:
|
||||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||||
@ -410,7 +408,7 @@ class LoRAModel: #(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
# TODO: diff/ia3/... format
|
# TODO: diff/ia3/... format
|
||||||
print(
|
print(
|
||||||
f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
|
f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -785,7 +785,7 @@ class ModelManager(object):
|
|||||||
if path in known_paths or path.parent in scanned_dirs:
|
if path in known_paths or path.parent in scanned_dirs:
|
||||||
scanned_dirs.add(path)
|
scanned_dirs.add(path)
|
||||||
continue
|
continue
|
||||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
||||||
new_models_found.update(installer.heuristic_import(path))
|
new_models_found.update(installer.heuristic_import(path))
|
||||||
scanned_dirs.add(path)
|
scanned_dirs.add(path)
|
||||||
|
|
||||||
@ -794,7 +794,8 @@ class ModelManager(object):
|
|||||||
if path in known_paths or path.parent in scanned_dirs:
|
if path in known_paths or path.parent in scanned_dirs:
|
||||||
continue
|
continue
|
||||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||||
new_models_found.update(installer.heuristic_import(path))
|
import_result = installer.heuristic_import(path)
|
||||||
|
new_models_found.update(import_result)
|
||||||
|
|
||||||
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
||||||
installed.update(new_models_found)
|
installed.update(new_models_found)
|
||||||
|
@ -78,7 +78,6 @@ class ModelProbe(object):
|
|||||||
format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
|
format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
|
||||||
else:
|
else:
|
||||||
format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
|
format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
|
||||||
|
|
||||||
model_info = None
|
model_info = None
|
||||||
try:
|
try:
|
||||||
model_type = cls.get_model_type_from_folder(model_path, model) \
|
model_type = cls.get_model_type_from_folder(model_path, model) \
|
||||||
@ -105,7 +104,7 @@ class ModelProbe(object):
|
|||||||
) else 512,
|
) else 512,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
raise
|
||||||
|
|
||||||
return model_info
|
return model_info
|
||||||
|
|
||||||
@ -127,6 +126,8 @@ class ModelProbe(object):
|
|||||||
return ModelType.Vae
|
return ModelType.Vae
|
||||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||||
return ModelType.Lora
|
return ModelType.Lora
|
||||||
|
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||||
|
return ModelType.Lora
|
||||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||||
return ModelType.ControlNet
|
return ModelType.ControlNet
|
||||||
elif key in {"emb_params", "string_to_param"}:
|
elif key in {"emb_params", "string_to_param"}:
|
||||||
@ -137,7 +138,7 @@ class ModelProbe(object):
|
|||||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
raise ValueError("Unable to determine model type")
|
raise ValueError(f"Unable to determine model type for {model_path}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
||||||
@ -167,7 +168,7 @@ class ModelProbe(object):
|
|||||||
return type
|
return type
|
||||||
|
|
||||||
# give up
|
# give up
|
||||||
raise ValueError("Unable to determine model type")
|
raise ValueError("Unable to determine model type for {folder_path}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
||||||
|
@ -678,9 +678,8 @@ def select_and_download_models(opt: Namespace):
|
|||||||
|
|
||||||
# this is where the TUI is called
|
# this is where the TUI is called
|
||||||
else:
|
else:
|
||||||
# needed because the torch library is loaded, even though we don't use it
|
# needed to support the probe() method running under a subprocess
|
||||||
# currently commented out because it has started generating errors (?)
|
torch.multiprocessing.set_start_method("spawn")
|
||||||
# torch.multiprocessing.set_start_method("spawn")
|
|
||||||
|
|
||||||
# the third argument is needed in the Windows 11 environment in
|
# the third argument is needed in the Windows 11 environment in
|
||||||
# order to launch and resize a console window running this program
|
# order to launch and resize a console window running this program
|
||||||
|
@ -1,15 +1,16 @@
|
|||||||
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
|
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
|
||||||
import { MultiSelect, MultiSelectProps } from '@mantine/core';
|
import { MultiSelect, MultiSelectProps } from '@mantine/core';
|
||||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||||
import { memo } from 'react';
|
import { RefObject, memo } from 'react';
|
||||||
import { mode } from 'theme/util/mode';
|
import { mode } from 'theme/util/mode';
|
||||||
|
|
||||||
type IAIMultiSelectProps = MultiSelectProps & {
|
type IAIMultiSelectProps = MultiSelectProps & {
|
||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
|
inputRef?: RefObject<HTMLInputElement>;
|
||||||
};
|
};
|
||||||
|
|
||||||
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||||
const { searchable = true, tooltip, ...rest } = props;
|
const { searchable = true, tooltip, inputRef, ...rest } = props;
|
||||||
const {
|
const {
|
||||||
base50,
|
base50,
|
||||||
base100,
|
base100,
|
||||||
@ -33,6 +34,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
|||||||
return (
|
return (
|
||||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||||
<MultiSelect
|
<MultiSelect
|
||||||
|
ref={inputRef}
|
||||||
searchable={searchable}
|
searchable={searchable}
|
||||||
styles={() => ({
|
styles={() => ({
|
||||||
label: {
|
label: {
|
||||||
|
@ -0,0 +1,33 @@
|
|||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { BiCode } from 'react-icons/bi';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
onClick: () => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
const AddEmbeddingButton = (props: Props) => {
|
||||||
|
const { onClick } = props;
|
||||||
|
return (
|
||||||
|
<IAIIconButton
|
||||||
|
size="sm"
|
||||||
|
aria-label="Add Embedding"
|
||||||
|
tooltip="Add Embedding"
|
||||||
|
icon={<BiCode />}
|
||||||
|
sx={{
|
||||||
|
p: 2,
|
||||||
|
color: 'base.700',
|
||||||
|
_hover: {
|
||||||
|
color: 'base.550',
|
||||||
|
},
|
||||||
|
_active: {
|
||||||
|
color: 'base.500',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
variant="link"
|
||||||
|
onClick={onClick}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(AddEmbeddingButton);
|
@ -0,0 +1,151 @@
|
|||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
Popover,
|
||||||
|
PopoverBody,
|
||||||
|
PopoverContent,
|
||||||
|
PopoverTrigger,
|
||||||
|
Text,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import {
|
||||||
|
PropsWithChildren,
|
||||||
|
forwardRef,
|
||||||
|
useCallback,
|
||||||
|
useMemo,
|
||||||
|
useRef,
|
||||||
|
} from 'react';
|
||||||
|
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
|
||||||
|
|
||||||
|
type EmbeddingSelectItem = {
|
||||||
|
label: string;
|
||||||
|
value: string;
|
||||||
|
description?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
type Props = PropsWithChildren & {
|
||||||
|
onSelect: (v: string) => void;
|
||||||
|
isOpen: boolean;
|
||||||
|
onClose: () => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ParamEmbeddingPopover = (props: Props) => {
|
||||||
|
const { onSelect, isOpen, onClose, children } = props;
|
||||||
|
const { data: embeddingQueryData } = useGetTextualInversionModelsQuery();
|
||||||
|
const inputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!embeddingQueryData) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: EmbeddingSelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(embeddingQueryData.entities, (embedding, _) => {
|
||||||
|
if (!embedding) return;
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: embedding.name,
|
||||||
|
label: embedding.name,
|
||||||
|
description: embedding.description,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [embeddingQueryData]);
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: string[]) => {
|
||||||
|
if (v.length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
onSelect(v[0]);
|
||||||
|
},
|
||||||
|
[onSelect]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Popover
|
||||||
|
initialFocusRef={inputRef}
|
||||||
|
isOpen={isOpen}
|
||||||
|
onClose={onClose}
|
||||||
|
placement="bottom"
|
||||||
|
openDelay={0}
|
||||||
|
closeDelay={0}
|
||||||
|
closeOnBlur={true}
|
||||||
|
returnFocusOnClose={true}
|
||||||
|
>
|
||||||
|
<PopoverTrigger>{children}</PopoverTrigger>
|
||||||
|
<PopoverContent
|
||||||
|
sx={{
|
||||||
|
p: 0,
|
||||||
|
top: -1,
|
||||||
|
shadow: 'dark-lg',
|
||||||
|
borderColor: 'accent.300',
|
||||||
|
borderWidth: '2px',
|
||||||
|
borderStyle: 'solid',
|
||||||
|
_dark: { borderColor: 'accent.400' },
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<PopoverBody
|
||||||
|
sx={{ p: 0, w: `calc(${PARAMETERS_PANEL_WIDTH} - 2rem )` }}
|
||||||
|
>
|
||||||
|
{data.length === 0 ? (
|
||||||
|
<Flex sx={{ justifyContent: 'center', p: 2 }}>
|
||||||
|
<Text
|
||||||
|
sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}
|
||||||
|
>
|
||||||
|
No Embeddings Loaded
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
) : (
|
||||||
|
<IAIMantineMultiSelect
|
||||||
|
inputRef={inputRef}
|
||||||
|
placeholder={'Add Embedding'}
|
||||||
|
value={[]}
|
||||||
|
data={data}
|
||||||
|
maxDropdownHeight={400}
|
||||||
|
nothingFound="No Matching Embeddings"
|
||||||
|
itemComponent={SelectItem}
|
||||||
|
disabled={data.length === 0}
|
||||||
|
filter={(value, selected, item: EmbeddingSelectItem) =>
|
||||||
|
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||||
|
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||||
|
}
|
||||||
|
onChange={handleChange}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</PopoverBody>
|
||||||
|
</PopoverContent>
|
||||||
|
</Popover>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ParamEmbeddingPopover;
|
||||||
|
|
||||||
|
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||||
|
value: string;
|
||||||
|
label: string;
|
||||||
|
description?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
|
||||||
|
({ label, description, ...others }: ItemProps, ref) => {
|
||||||
|
return (
|
||||||
|
<div ref={ref} {...others}>
|
||||||
|
<div>
|
||||||
|
<Text>{label}</Text>
|
||||||
|
{description && (
|
||||||
|
<Text size="xs" color="base.600">
|
||||||
|
{description}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
SelectItem.displayName = 'SelectItem';
|
@ -23,6 +23,7 @@ export const makeSelector = (image_name: string) =>
|
|||||||
({ gallery }) => {
|
({ gallery }) => {
|
||||||
const isSelected = gallery.selection.includes(image_name);
|
const isSelected = gallery.selection.includes(image_name);
|
||||||
const selectionCount = gallery.selection.length;
|
const selectionCount = gallery.selection.length;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isSelected,
|
isSelected,
|
||||||
selectionCount,
|
selectionCount,
|
||||||
@ -117,7 +118,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
resetIcon={<FaTrash />}
|
resetIcon={<FaTrash />}
|
||||||
resetTooltip="Delete image"
|
resetTooltip="Delete image"
|
||||||
imageSx={{ w: 'full', h: 'full' }}
|
imageSx={{ w: 'full', h: 'full' }}
|
||||||
withResetIcon
|
// withResetIcon // removed bc it's too easy to accidentally delete images
|
||||||
isDropDisabled={true}
|
isDropDisabled={true}
|
||||||
isUploadDisabled={true}
|
isUploadDisabled={true}
|
||||||
/>
|
/>
|
||||||
|
@ -4,7 +4,12 @@ import IAIIconButton from 'common/components/IAIIconButton';
|
|||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { FaTrash } from 'react-icons/fa';
|
import { FaTrash } from 'react-icons/fa';
|
||||||
import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice';
|
import {
|
||||||
|
Lora,
|
||||||
|
loraRemoved,
|
||||||
|
loraWeightChanged,
|
||||||
|
loraWeightReset,
|
||||||
|
} from '../store/loraSlice';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
lora: Lora;
|
lora: Lora;
|
||||||
@ -22,7 +27,7 @@ const ParamLora = (props: Props) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const handleReset = useCallback(() => {
|
const handleReset = useCallback(() => {
|
||||||
dispatch(loraWeightChanged({ id: lora.id, weight: 1 }));
|
dispatch(loraWeightReset(lora.id));
|
||||||
}, [dispatch, lora.id]);
|
}, [dispatch, lora.id]);
|
||||||
|
|
||||||
const handleRemoveLora = useCallback(() => {
|
const handleRemoveLora = useCallback(() => {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Text } from '@chakra-ui/react';
|
import { Flex, Text } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
@ -61,6 +61,16 @@ const ParamLoraSelect = () => {
|
|||||||
[dispatch, lorasQueryData?.entities]
|
[dispatch, lorasQueryData?.entities]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if (lorasQueryData?.ids.length === 0) {
|
||||||
|
return (
|
||||||
|
<Flex sx={{ justifyContent: 'center', p: 2 }}>
|
||||||
|
<Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
|
||||||
|
No LoRAs Loaded
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineMultiSelect
|
<IAIMantineMultiSelect
|
||||||
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}
|
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}
|
||||||
|
@ -8,7 +8,7 @@ export type Lora = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
|
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
|
||||||
weight: 1,
|
weight: 0.75,
|
||||||
};
|
};
|
||||||
|
|
||||||
export type LoraState = {
|
export type LoraState = {
|
||||||
@ -38,9 +38,14 @@ export const loraSlice = createSlice({
|
|||||||
const { id, weight } = action.payload;
|
const { id, weight } = action.payload;
|
||||||
state.loras[id].weight = weight;
|
state.loras[id].weight = weight;
|
||||||
},
|
},
|
||||||
|
loraWeightReset: (state, action: PayloadAction<string>) => {
|
||||||
|
const id = action.payload;
|
||||||
|
state.loras[id].weight = defaultLoRAConfig.weight;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions;
|
export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } =
|
||||||
|
loraSlice.actions;
|
||||||
|
|
||||||
export default loraSlice.reducer;
|
export default loraSlice.reducer;
|
||||||
|
@ -1,29 +1,107 @@
|
|||||||
import { FormControl } from '@chakra-ui/react';
|
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAITextarea from 'common/components/IAITextarea';
|
import IAITextarea from 'common/components/IAITextarea';
|
||||||
|
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
|
||||||
|
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
|
||||||
import { setNegativePrompt } from 'features/parameters/store/generationSlice';
|
import { setNegativePrompt } from 'features/parameters/store/generationSlice';
|
||||||
|
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
|
||||||
|
import { flushSync } from 'react-dom';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
const ParamNegativeConditioning = () => {
|
const ParamNegativeConditioning = () => {
|
||||||
const negativePrompt = useAppSelector(
|
const negativePrompt = useAppSelector(
|
||||||
(state: RootState) => state.generation.negativePrompt
|
(state: RootState) => state.generation.negativePrompt
|
||||||
);
|
);
|
||||||
|
const promptRef = useRef<HTMLTextAreaElement>(null);
|
||||||
|
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleChangePrompt = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
|
dispatch(setNegativePrompt(e.target.value));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
const handleKeyDown = useCallback(
|
||||||
|
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||||
|
if (e.key === '<') {
|
||||||
|
onOpen();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[onOpen]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleSelectEmbedding = useCallback(
|
||||||
|
(v: string) => {
|
||||||
|
if (!promptRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is where we insert the TI trigger
|
||||||
|
const caret = promptRef.current.selectionStart;
|
||||||
|
|
||||||
|
if (caret === undefined) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let newPrompt = negativePrompt.slice(0, caret);
|
||||||
|
|
||||||
|
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||||
|
newPrompt += '<';
|
||||||
|
}
|
||||||
|
|
||||||
|
newPrompt += `${v}>`;
|
||||||
|
|
||||||
|
// we insert the cursor after the `>`
|
||||||
|
const finalCaretPos = newPrompt.length;
|
||||||
|
|
||||||
|
newPrompt += negativePrompt.slice(caret);
|
||||||
|
|
||||||
|
// must flush dom updates else selection gets reset
|
||||||
|
flushSync(() => {
|
||||||
|
dispatch(setNegativePrompt(newPrompt));
|
||||||
|
});
|
||||||
|
|
||||||
|
// set the caret position to just after the TI trigger promptRef.current.selectionStart = finalCaretPos;
|
||||||
|
promptRef.current.selectionEnd = finalCaretPos;
|
||||||
|
onClose();
|
||||||
|
},
|
||||||
|
[dispatch, onClose, negativePrompt]
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<FormControl>
|
<FormControl>
|
||||||
<IAITextarea
|
<ParamEmbeddingPopover
|
||||||
id="negativePrompt"
|
isOpen={isOpen}
|
||||||
name="negativePrompt"
|
onClose={onClose}
|
||||||
value={negativePrompt}
|
onSelect={handleSelectEmbedding}
|
||||||
onChange={(e) => dispatch(setNegativePrompt(e.target.value))}
|
>
|
||||||
placeholder={t('parameters.negativePromptPlaceholder')}
|
<IAITextarea
|
||||||
fontSize="sm"
|
id="negativePrompt"
|
||||||
minH={16}
|
name="negativePrompt"
|
||||||
/>
|
ref={promptRef}
|
||||||
|
value={negativePrompt}
|
||||||
|
placeholder={t('parameters.negativePromptPlaceholder')}
|
||||||
|
onChange={handleChangePrompt}
|
||||||
|
onKeyDown={handleKeyDown}
|
||||||
|
resize="vertical"
|
||||||
|
fontSize="sm"
|
||||||
|
minH={16}
|
||||||
|
/>
|
||||||
|
</ParamEmbeddingPopover>
|
||||||
|
{!isOpen && (
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
insetInlineEnd: 0,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<AddEmbeddingButton onClick={onOpen} />
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Box, FormControl } from '@chakra-ui/react';
|
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
|
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
|
||||||
@ -11,12 +11,15 @@ import {
|
|||||||
} from 'features/parameters/store/generationSlice';
|
} from 'features/parameters/store/generationSlice';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
|
||||||
import { isEqual } from 'lodash-es';
|
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { userInvoked } from 'app/store/actions';
|
import { userInvoked } from 'app/store/actions';
|
||||||
import IAITextarea from 'common/components/IAITextarea';
|
import IAITextarea from 'common/components/IAITextarea';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
|
||||||
|
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
|
||||||
|
import { isEqual } from 'lodash-es';
|
||||||
|
import { flushSync } from 'react-dom';
|
||||||
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
const promptInputSelector = createSelector(
|
const promptInputSelector = createSelector(
|
||||||
[(state: RootState) => state.generation, activeTabNameSelector],
|
[(state: RootState) => state.generation, activeTabNameSelector],
|
||||||
@ -40,14 +43,15 @@ const ParamPositiveConditioning = () => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
|
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
|
||||||
const isReady = useIsReadyToInvoke();
|
const isReady = useIsReadyToInvoke();
|
||||||
|
|
||||||
const promptRef = useRef<HTMLTextAreaElement>(null);
|
const promptRef = useRef<HTMLTextAreaElement>(null);
|
||||||
|
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const handleChangePrompt = useCallback(
|
||||||
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) => {
|
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
dispatch(setPositivePrompt(e.target.value));
|
dispatch(setPositivePrompt(e.target.value));
|
||||||
};
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'alt+a',
|
'alt+a',
|
||||||
@ -57,6 +61,45 @@ const ParamPositiveConditioning = () => {
|
|||||||
[]
|
[]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const handleSelectEmbedding = useCallback(
|
||||||
|
(v: string) => {
|
||||||
|
if (!promptRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is where we insert the TI trigger
|
||||||
|
const caret = promptRef.current.selectionStart;
|
||||||
|
|
||||||
|
if (caret === undefined) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let newPrompt = prompt.slice(0, caret);
|
||||||
|
|
||||||
|
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||||
|
newPrompt += '<';
|
||||||
|
}
|
||||||
|
|
||||||
|
newPrompt += `${v}>`;
|
||||||
|
|
||||||
|
// we insert the cursor after the `>`
|
||||||
|
const finalCaretPos = newPrompt.length;
|
||||||
|
|
||||||
|
newPrompt += prompt.slice(caret);
|
||||||
|
|
||||||
|
// must flush dom updates else selection gets reset
|
||||||
|
flushSync(() => {
|
||||||
|
dispatch(setPositivePrompt(newPrompt));
|
||||||
|
});
|
||||||
|
|
||||||
|
// set the caret position to just after the TI trigger
|
||||||
|
promptRef.current.selectionStart = finalCaretPos;
|
||||||
|
promptRef.current.selectionEnd = finalCaretPos;
|
||||||
|
onClose();
|
||||||
|
},
|
||||||
|
[dispatch, onClose, prompt]
|
||||||
|
);
|
||||||
|
|
||||||
const handleKeyDown = useCallback(
|
const handleKeyDown = useCallback(
|
||||||
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||||
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
|
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
|
||||||
@ -64,25 +107,50 @@ const ParamPositiveConditioning = () => {
|
|||||||
dispatch(clampSymmetrySteps());
|
dispatch(clampSymmetrySteps());
|
||||||
dispatch(userInvoked(activeTabName));
|
dispatch(userInvoked(activeTabName));
|
||||||
}
|
}
|
||||||
|
if (e.key === '<') {
|
||||||
|
onOpen();
|
||||||
|
}
|
||||||
},
|
},
|
||||||
[dispatch, activeTabName, isReady]
|
[isReady, dispatch, activeTabName, onOpen]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
|
||||||
|
// const target = e.target as HTMLTextAreaElement;
|
||||||
|
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
|
||||||
|
// };
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box>
|
<Box>
|
||||||
<FormControl>
|
<FormControl>
|
||||||
<IAITextarea
|
<ParamEmbeddingPopover
|
||||||
id="prompt"
|
isOpen={isOpen}
|
||||||
name="prompt"
|
onClose={onClose}
|
||||||
placeholder={t('parameters.positivePromptPlaceholder')}
|
onSelect={handleSelectEmbedding}
|
||||||
value={prompt}
|
>
|
||||||
onChange={handleChangePrompt}
|
<IAITextarea
|
||||||
onKeyDown={handleKeyDown}
|
id="prompt"
|
||||||
resize="vertical"
|
name="prompt"
|
||||||
ref={promptRef}
|
ref={promptRef}
|
||||||
minH={32}
|
value={prompt}
|
||||||
/>
|
placeholder={t('parameters.positivePromptPlaceholder')}
|
||||||
|
onChange={handleChangePrompt}
|
||||||
|
onKeyDown={handleKeyDown}
|
||||||
|
resize="vertical"
|
||||||
|
minH={32}
|
||||||
|
/>
|
||||||
|
</ParamEmbeddingPopover>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
|
{!isOpen && (
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 6,
|
||||||
|
insetInlineEnd: 0,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<AddEmbeddingButton onClick={onOpen} />
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import { Tooltip } from '@chakra-ui/react';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIIconButton, {
|
import IAIIconButton, {
|
||||||
IAIIconButtonProps,
|
IAIIconButtonProps,
|
||||||
@ -25,26 +24,25 @@ const PinParametersPanelButton = (props: PinParametersPanelButtonProps) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Tooltip label={t('common.pinOptionsPanel')}>
|
<IAIIconButton
|
||||||
<IAIIconButton
|
{...props}
|
||||||
{...props}
|
tooltip={t('common.pinOptionsPanel')}
|
||||||
aria-label={t('common.pinOptionsPanel')}
|
aria-label={t('common.pinOptionsPanel')}
|
||||||
onClick={handleClickPinOptionsPanel}
|
onClick={handleClickPinOptionsPanel}
|
||||||
icon={shouldPinParametersPanel ? <BsPinAngleFill /> : <BsPinAngle />}
|
icon={shouldPinParametersPanel ? <BsPinAngleFill /> : <BsPinAngle />}
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
size="sm"
|
size="sm"
|
||||||
sx={{
|
sx={{
|
||||||
color: 'base.700',
|
color: 'base.700',
|
||||||
_hover: {
|
_hover: {
|
||||||
color: 'base.550',
|
color: 'base.550',
|
||||||
},
|
},
|
||||||
_active: {
|
_active: {
|
||||||
color: 'base.500',
|
color: 'base.500',
|
||||||
},
|
},
|
||||||
...sx,
|
...sx,
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</Tooltip>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
|
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||||
import { setActiveTabReducer } from './extraReducers';
|
import { setActiveTabReducer } from './extraReducers';
|
||||||
import { InvokeTabName } from './tabMap';
|
import { InvokeTabName } from './tabMap';
|
||||||
import { AddNewModelType, UIState } from './uiTypes';
|
import { AddNewModelType, UIState } from './uiTypes';
|
||||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
|
||||||
|
|
||||||
export const initialUIState: UIState = {
|
export const initialUIState: UIState = {
|
||||||
activeTab: 0,
|
activeTab: 0,
|
||||||
@ -19,6 +19,7 @@ export const initialUIState: UIState = {
|
|||||||
shouldShowGallery: true,
|
shouldShowGallery: true,
|
||||||
shouldHidePreview: false,
|
shouldHidePreview: false,
|
||||||
shouldShowProgressInViewer: true,
|
shouldShowProgressInViewer: true,
|
||||||
|
shouldShowEmbeddingPicker: false,
|
||||||
favoriteSchedulers: [],
|
favoriteSchedulers: [],
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -96,6 +97,9 @@ export const uiSlice = createSlice({
|
|||||||
) => {
|
) => {
|
||||||
state.favoriteSchedulers = action.payload;
|
state.favoriteSchedulers = action.payload;
|
||||||
},
|
},
|
||||||
|
toggleEmbeddingPicker: (state) => {
|
||||||
|
state.shouldShowEmbeddingPicker = !state.shouldShowEmbeddingPicker;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
extraReducers(builder) {
|
extraReducers(builder) {
|
||||||
builder.addCase(initialImageChanged, (state) => {
|
builder.addCase(initialImageChanged, (state) => {
|
||||||
@ -122,6 +126,7 @@ export const {
|
|||||||
toggleGalleryPanel,
|
toggleGalleryPanel,
|
||||||
setShouldShowProgressInViewer,
|
setShouldShowProgressInViewer,
|
||||||
favoriteSchedulersChanged,
|
favoriteSchedulersChanged,
|
||||||
|
toggleEmbeddingPicker,
|
||||||
} = uiSlice.actions;
|
} = uiSlice.actions;
|
||||||
|
|
||||||
export default uiSlice.reducer;
|
export default uiSlice.reducer;
|
||||||
|
@ -27,5 +27,6 @@ export interface UIState {
|
|||||||
shouldPinGallery: boolean;
|
shouldPinGallery: boolean;
|
||||||
shouldShowGallery: boolean;
|
shouldShowGallery: boolean;
|
||||||
shouldShowProgressInViewer: boolean;
|
shouldShowProgressInViewer: boolean;
|
||||||
|
shouldShowEmbeddingPicker: boolean;
|
||||||
favoriteSchedulers: SchedulerParam[];
|
favoriteSchedulers: SchedulerParam[];
|
||||||
}
|
}
|
||||||
|
@ -1,18 +1,18 @@
|
|||||||
import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
|
import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||||
import { io, Socket } from 'socket.io-client';
|
import { Socket, io } from 'socket.io-client';
|
||||||
|
|
||||||
|
import { AppThunkDispatch, RootState } from 'app/store/store';
|
||||||
|
import { getTimestamp } from 'common/util/getTimestamp';
|
||||||
|
import { sessionCreated } from 'services/api/thunks/session';
|
||||||
import {
|
import {
|
||||||
ClientToServerEvents,
|
ClientToServerEvents,
|
||||||
ServerToClientEvents,
|
ServerToClientEvents,
|
||||||
} from 'services/events/types';
|
} from 'services/events/types';
|
||||||
import { socketSubscribed, socketUnsubscribed } from './actions';
|
import { socketSubscribed, socketUnsubscribed } from './actions';
|
||||||
import { AppThunkDispatch, RootState } from 'app/store/store';
|
|
||||||
import { getTimestamp } from 'common/util/getTimestamp';
|
|
||||||
import { sessionCreated } from 'services/api/thunks/session';
|
|
||||||
// import { OpenAPI } from 'services/api/types';
|
// import { OpenAPI } from 'services/api/types';
|
||||||
import { setEventListeners } from 'services/events/util/setEventListeners';
|
|
||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { $authToken, $baseUrl } from 'services/api/client';
|
import { $authToken, $baseUrl } from 'services/api/client';
|
||||||
|
import { setEventListeners } from 'services/events/util/setEventListeners';
|
||||||
|
|
||||||
const socketioLog = log.child({ namespace: 'socketio' });
|
const socketioLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
@ -88,7 +88,7 @@ export const socketMiddleware = () => {
|
|||||||
socketSubscribed({
|
socketSubscribed({
|
||||||
sessionId: sessionId,
|
sessionId: sessionId,
|
||||||
timestamp: getTimestamp(),
|
timestamp: getTimestamp(),
|
||||||
boardId: getState().boards.selectedBoardId,
|
boardId: getState().gallery.selectedBoardId,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user