feat: Add Style Prompts to Linear UI

This commit is contained in:
blessedcoolant 2023-07-25 18:17:30 +12:00 committed by psychedelicious
parent 9f94d0e52a
commit b0ebd148fa
10 changed files with 407 additions and 3 deletions

View File

@ -44,6 +44,8 @@ export const buildLinearSDXLImageToImageGraph = (
shouldUseNoiseSettings,
} = state.generation;
const { positiveStylePrompt, negativeStylePrompt } = state.sdxl;
// TODO: add batch functionality
// const {
// isEnabled: isBatchEnabled,
@ -90,11 +92,13 @@ export const buildLinearSDXLImageToImageGraph = (
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: positiveStylePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: negativeStylePrompt,
},
[NOISE]: {
type: 'noise',

View File

@ -32,6 +32,8 @@ export const buildLinearSDXLTextToImageGraph = (
shouldUseNoiseSettings,
} = state.generation;
const { positiveStylePrompt, negativeStylePrompt } = state.sdxl;
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
@ -63,11 +65,13 @@ export const buildLinearSDXLTextToImageGraph = (
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: positiveStylePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: negativeStylePrompt,
},
[NOISE]: {
type: 'noise',

View File

@ -0,0 +1,32 @@
import { RootState } from 'app/store/store';
import { MetadataAccumulatorInvocation } from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import { METADATA_ACCUMULATOR, SDXL_TEXT_TO_LATENTS } from './constants';
export const addSDXLRefinerToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): void => {
const { shouldUseSDXLRefiner, model } = state.generation;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (!shouldUseSDXLRefiner) return;
// Unplug SDXL Text To Latents To Latents To Image
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === SDXL_TEXT_TO_LATENTS &&
['latents'].includes(e.source.field)
)
);
// graph.nodes[SDXL_REFINER_MODEL_LOADER] = {
// id: SDXL_REFINER_MODEL_LOADER,
// type: 'sdxl_refiner_model_loader',
// };
};

View File

@ -26,6 +26,7 @@ export const SCALE = 'scale_image';
export const SDXL_MODEL_LOADER = 'sdxl_model_loader';
export const SDXL_TEXT_TO_LATENTS = 't2l_sdxl';
export const SDXL_LATENTS_TO_LATENTS = 'l2l_sdxl';
export const SDXL_REFINER_MODEL_LOADER = 'sdxl_refiner_model_loader';
// friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';

View File

@ -0,0 +1,149 @@
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
import { createSelector } from '@reduxjs/toolkit';
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { userInvoked } from 'app/store/actions';
import IAITextarea from 'common/components/IAITextarea';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { isEqual } from 'lodash-es';
import { flushSync } from 'react-dom';
import { setNegativeStylePromptSDXL } from '../store/sdxlSlice';
const promptInputSelector = createSelector(
[stateSelector, activeTabNameSelector],
({ sdxl }, activeTabName) => {
return {
prompt: sdxl.negativeStylePrompt,
activeTabName,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Prompt input text area.
*/
const ParamSDXLNegativeStyleConditioning = () => {
const dispatch = useAppDispatch();
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
const isReady = useIsReadyToInvoke();
const promptRef = useRef<HTMLTextAreaElement>(null);
const { isOpen, onClose, onOpen } = useDisclosure();
const handleChangePrompt = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setNegativeStylePromptSDXL(e.target.value));
},
[dispatch]
);
const handleSelectEmbedding = useCallback(
(v: string) => {
if (!promptRef.current) {
return;
}
// this is where we insert the TI trigger
const caret = promptRef.current.selectionStart;
if (caret === undefined) {
return;
}
let newPrompt = prompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;
newPrompt += prompt.slice(caret);
// must flush dom updates else selection gets reset
flushSync(() => {
dispatch(setNegativeStylePromptSDXL(newPrompt));
});
// set the caret position to just after the TI trigger
promptRef.current.selectionStart = finalCaretPos;
promptRef.current.selectionEnd = finalCaretPos;
onClose();
},
[dispatch, onClose, prompt]
);
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
e.preventDefault();
dispatch(clampSymmetrySteps());
dispatch(userInvoked(activeTabName));
}
if (isEmbeddingEnabled && e.key === '<') {
onOpen();
}
},
[isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled]
);
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
// const target = e.target as HTMLTextAreaElement;
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
// };
return (
<Box position="relative">
<FormControl>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea
id="prompt"
name="prompt"
ref={promptRef}
value={prompt}
placeholder="Negative Style Prompt"
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
fontSize="sm"
minH={16}
/>
</ParamEmbeddingPopover>
</FormControl>
{!isOpen && isEmbeddingEnabled && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</Box>
);
};
export default ParamSDXLNegativeStyleConditioning;

View File

@ -0,0 +1,148 @@
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
import { createSelector } from '@reduxjs/toolkit';
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { userInvoked } from 'app/store/actions';
import IAITextarea from 'common/components/IAITextarea';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { isEqual } from 'lodash-es';
import { flushSync } from 'react-dom';
import { setPositiveStylePromptSDXL } from '../store/sdxlSlice';
const promptInputSelector = createSelector(
[stateSelector, activeTabNameSelector],
({ sdxl }, activeTabName) => {
return {
prompt: sdxl.positiveStylePrompt,
activeTabName,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Prompt input text area.
*/
const ParamSDXLPositiveStyleConditioning = () => {
const dispatch = useAppDispatch();
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
const isReady = useIsReadyToInvoke();
const promptRef = useRef<HTMLTextAreaElement>(null);
const { isOpen, onClose, onOpen } = useDisclosure();
const handleChangePrompt = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setPositiveStylePromptSDXL(e.target.value));
},
[dispatch]
);
const handleSelectEmbedding = useCallback(
(v: string) => {
if (!promptRef.current) {
return;
}
// this is where we insert the TI trigger
const caret = promptRef.current.selectionStart;
if (caret === undefined) {
return;
}
let newPrompt = prompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;
newPrompt += prompt.slice(caret);
// must flush dom updates else selection gets reset
flushSync(() => {
dispatch(setPositiveStylePromptSDXL(newPrompt));
});
// set the caret position to just after the TI trigger
promptRef.current.selectionStart = finalCaretPos;
promptRef.current.selectionEnd = finalCaretPos;
onClose();
},
[dispatch, onClose, prompt]
);
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
e.preventDefault();
dispatch(clampSymmetrySteps());
dispatch(userInvoked(activeTabName));
}
if (isEmbeddingEnabled && e.key === '<') {
onOpen();
}
},
[isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled]
);
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
// const target = e.target as HTMLTextAreaElement;
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
// };
return (
<Box position="relative">
<FormControl>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea
id="prompt"
name="prompt"
ref={promptRef}
value={prompt}
placeholder="Positive Style Prompt"
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
minH={16}
/>
</ParamEmbeddingPopover>
</FormControl>
{!isOpen && isEmbeddingEnabled && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</Box>
);
};
export default ParamSDXLPositiveStyleConditioning;

View File

@ -0,0 +1,26 @@
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import ImageToImageTabCoreParameters from 'features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters';
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
const SDXLImageToImageTabParameters = () => {
return (
<>
<ParamPositiveConditioning />
<ParamSDXLPositiveStyleConditioning />
<ParamNegativeConditioning />
<ParamSDXLNegativeStyleConditioning />
<ProcessButtons />
<ImageToImageTabCoreParameters />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
</>
);
};
export default SDXLImageToImageTabParameters;

View File

@ -0,0 +1,25 @@
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
const SDXLTextToImageTabParameters = () => {
return (
<>
<ParamPositiveConditioning />
<ParamSDXLPositiveStyleConditioning />
<ParamNegativeConditioning />
<ParamSDXLNegativeStyleConditioning />
<ProcessButtons />
<TextToImageTabCoreParameters />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
</>
);
};
export default SDXLTextToImageTabParameters;

View File

@ -1,7 +1,9 @@
import { Box, Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import InitialImageDisplay from 'features/parameters/components/Parameters/ImageToImage/InitialImageDisplay';
import SDXLImageToImageTabParameters from 'features/sdxl/components/SDXLImageToImageTabParameters';
import { memo, useCallback, useRef } from 'react';
import {
ImperativePanelGroupHandle,
@ -16,6 +18,7 @@ import ImageToImageTabParameters from './ImageToImageTabParameters';
const ImageToImageTab = () => {
const dispatch = useAppDispatch();
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
const model = useAppSelector((state: RootState) => state.generation.model);
const handleDoubleClickHandle = useCallback(() => {
if (!panelGroupRef.current) {
@ -28,7 +31,11 @@ const ImageToImageTab = () => {
return (
<Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
<ParametersPinnedWrapper>
<ImageToImageTabParameters />
{model && model.base_model === 'sdxl' ? (
<SDXLImageToImageTabParameters />
) : (
<ImageToImageTabParameters />
)}
</ParametersPinnedWrapper>
<Box sx={{ w: 'full', h: 'full' }}>
<PanelGroup

View File

@ -1,14 +1,22 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import TextToImageSDXLTabParameters from 'features/sdxl/components/SDXLTextToImageTabParameters';
import { memo } from 'react';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
import TextToImageTabMain from './TextToImageTabMain';
import TextToImageTabParameters from './TextToImageTabParameters';
const TextToImageTab = () => {
const model = useAppSelector((state: RootState) => state.generation.model);
return (
<Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
<ParametersPinnedWrapper>
<TextToImageTabParameters />
{model && model.base_model === 'sdxl' ? (
<TextToImageSDXLTabParameters />
) : (
<TextToImageTabParameters />
)}
</ParametersPinnedWrapper>
<TextToImageTabMain />
</Flex>