feat(ui): update UI for new metadata

- Update for new routes
- Update model storage in state to be `MainModelField` type instead of `string`, simplifies a lot of model handling
- Update model-related stuff for model `name` --> `model_name`
- Update linear graphs to use `MetadataAccumulator`
- Update `ImageMetadataViewer` UI
- Ensure all `recall` functions work (well, the ones that are active anyways)
This commit is contained in:
psychedelicious 2023-07-13 01:18:32 +10:00
parent bddc04af96
commit a43c900961
39 changed files with 1060 additions and 669 deletions

View File

@ -51,6 +51,7 @@ import {
} from './listeners/imageUrlsReceived'; } from './listeners/imageUrlsReceived';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected'; import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addModelSelectedListener } from './listeners/modelSelected'; import { addModelSelectedListener } from './listeners/modelSelected';
import { addModelsLoadedListener } from './listeners/modelsLoaded';
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema'; import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
import { import {
addReceivedPageOfImagesFulfilledListener, addReceivedPageOfImagesFulfilledListener,
@ -224,3 +225,4 @@ addModelSelectedListener();
// app startup // app startup
addAppStartedListener(); addAppStartedListener();
addModelsLoadedListener();

View File

@ -1,13 +1,13 @@
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/api/thunks/image';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { controlNetImageProcessed } from 'features/controlNet/store/actions'; import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import { Graph } from 'services/api/types';
import { sessionCreated } from 'services/api/thunks/session';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { socketInvocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/api/guards';
import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { isImageOutput } from 'services/api/guards';
import { imageDTOReceived } from 'services/api/thunks/image';
import { sessionCreated } from 'services/api/thunks/session';
import { Graph } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'controlNet' }); const moduleLog = log.child({ namespace: 'controlNet' });
@ -63,10 +63,8 @@ export const addControlNetImageProcessedListener = () => {
// Wait for the ImageDTO to be received // Wait for the ImageDTO to be received
const [imageMetadataReceivedAction] = await take( const [imageMetadataReceivedAction] = await take(
( (action): action is ReturnType<typeof imageDTOReceived.fulfilled> =>
action imageDTOReceived.fulfilled.match(action) &&
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
imageMetadataReceived.fulfilled.match(action) &&
action.payload.image_name === image_name action.payload.image_name === image_name
); );
const processedControlImage = imageMetadataReceivedAction.payload; const processedControlImage = imageMetadataReceivedAction.payload;

View File

@ -1,7 +1,7 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/api/thunks/image';
import { boardImagesApi } from 'services/api/endpoints/boardImages'; import { boardImagesApi } from 'services/api/endpoints/boardImages';
import { imageDTOReceived } from 'services/api/thunks/image';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'boards' }); const moduleLog = log.child({ namespace: 'boards' });
@ -17,7 +17,7 @@ export const addImageAddedToBoardFulfilledListener = () => {
); );
dispatch( dispatch(
imageMetadataReceived({ imageDTOReceived({
image_name, image_name,
}) })
); );

View File

@ -1,13 +1,13 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived, imageUpdated } from 'services/api/thunks/image';
import { imageUpserted } from 'features/gallery/store/gallerySlice'; import { imageUpserted } from 'features/gallery/store/gallerySlice';
import { imageDTOReceived, imageUpdated } from 'services/api/thunks/image';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
export const addImageMetadataReceivedFulfilledListener = () => { export const addImageMetadataReceivedFulfilledListener = () => {
startAppListening({ startAppListening({
actionCreator: imageMetadataReceived.fulfilled, actionCreator: imageDTOReceived.fulfilled,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const image = action.payload; const image = action.payload;
@ -40,7 +40,7 @@ export const addImageMetadataReceivedFulfilledListener = () => {
export const addImageMetadataReceivedRejectedListener = () => { export const addImageMetadataReceivedRejectedListener = () => {
startAppListening({ startAppListening({
actionCreator: imageMetadataReceived.rejected, actionCreator: imageDTOReceived.rejected,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
moduleLog.debug( moduleLog.debug(
{ data: { image: action.meta.arg } }, { data: { image: action.meta.arg } },

View File

@ -1,7 +1,7 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/api/thunks/image';
import { boardImagesApi } from 'services/api/endpoints/boardImages'; import { boardImagesApi } from 'services/api/endpoints/boardImages';
import { imageDTOReceived } from 'services/api/thunks/image';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'boards' }); const moduleLog = log.child({ namespace: 'boards' });
@ -17,7 +17,7 @@ export const addImageRemovedFromBoardFulfilledListener = () => {
); );
dispatch( dispatch(
imageMetadataReceived({ imageDTOReceived({
image_name, image_name,
}) })
); );

View File

@ -14,7 +14,7 @@ export const addModelSelectedListener = () => {
actionCreator: modelSelected, actionCreator: modelSelected,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const state = getState(); const state = getState();
const [base_model, type, name] = action.payload.split('/'); const { base_model, model_name } = action.payload;
if (state.generation.model?.base_model !== base_model) { if (state.generation.model?.base_model !== base_model) {
dispatch( dispatch(
@ -30,11 +30,7 @@ export const addModelSelectedListener = () => {
// TODO: controlnet cleared // TODO: controlnet cleared
} }
const newModel = zMainModel.parse({ const newModel = zMainModel.parse(action.payload);
id: action.payload,
base_model,
name,
});
dispatch(modelChanged(newModel)); dispatch(modelChanged(newModel));
}, },

View File

@ -0,0 +1,42 @@
import { modelChanged } from 'features/parameters/store/generationSlice';
import { some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..';
export const addModelsLoadedListener = () => {
startAppListening({
matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const currentModel = getState().generation.model;
const isCurrentModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === currentModel?.model_name &&
m?.base_model === currentModel?.base_model
);
if (isCurrentModelAvailable) {
return;
}
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
if (!firstModel) {
// No models loaded at all
dispatch(modelChanged(null));
return;
}
dispatch(
modelChanged({
base_model: firstModel.base_model,
model_name: firstModel.model_name,
})
);
},
});
};

View File

@ -1,15 +1,15 @@
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import { progressImageSet } from 'features/system/store/systemSlice';
import { boardImagesApi } from 'services/api/endpoints/boardImages';
import { isImageOutput } from 'services/api/guards';
import { imageDTOReceived } from 'services/api/thunks/image';
import { sessionCanceled } from 'services/api/thunks/session';
import { import {
appSocketInvocationComplete, appSocketInvocationComplete,
socketInvocationComplete, socketInvocationComplete,
} from 'services/events/actions'; } from 'services/events/actions';
import { imageMetadataReceived } from 'services/api/thunks/image'; import { startAppListening } from '../..';
import { sessionCanceled } from 'services/api/thunks/session';
import { isImageOutput } from 'services/api/guards';
import { progressImageSet } from 'features/system/store/systemSlice';
import { boardImagesApi } from 'services/api/endpoints/boardImages';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
const nodeDenylist = ['dataURL_image']; const nodeDenylist = ['dataURL_image'];
@ -42,13 +42,13 @@ export const addInvocationCompleteEventListener = () => {
// Get its metadata // Get its metadata
dispatch( dispatch(
imageMetadataReceived({ imageDTOReceived({
image_name, image_name,
}) })
); );
const [{ payload: imageDTO }] = await take( const [{ payload: imageDTO }] = await take(
imageMetadataReceived.fulfilled.match imageDTOReceived.fulfilled.match
); );
// Handle canvas image // Handle canvas image

View File

@ -47,8 +47,8 @@ const ParamEmbeddingPopover = (props: Props) => {
const disabled = currentMainModel?.base_model !== embedding.base_model; const disabled = currentMainModel?.base_model !== embedding.base_model;
data.push({ data.push({
value: embedding.name, value: embedding.model_name,
label: embedding.name, label: embedding.model_name,
group: MODEL_TYPE_MAP[embedding.base_model], group: MODEL_TYPE_MAP[embedding.base_model],
disabled, disabled,
tooltip: disabled tooltip: disabled

View File

@ -118,7 +118,6 @@ const CurrentImagePreview = () => {
width: 'full', width: 'full',
height: 'full', height: 'full',
borderRadius: 'base', borderRadius: 'base',
overflow: 'scroll',
}} }}
> >
<ImageMetadataViewer image={imageDTO} /> <ImageMetadataViewer image={imageDTO} />

View File

@ -0,0 +1,212 @@
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { useCallback } from 'react';
import { UnsafeImageMetadata } from 'services/api/endpoints/images';
import MetadataItem from './MetadataItem';
type Props = {
metadata?: UnsafeImageMetadata['metadata'];
};
const ImageMetadataActions = (props: Props) => {
const { metadata } = props;
const {
recallBothPrompts,
recallPositivePrompt,
recallNegativePrompt,
recallSeed,
recallInitialImage,
recallCfgScale,
recallModel,
recallScheduler,
recallSteps,
recallWidth,
recallHeight,
recallStrength,
recallAllParameters,
} = useRecallParameters();
const handleRecallPositivePrompt = useCallback(() => {
recallPositivePrompt(metadata?.positive_prompt);
}, [metadata?.positive_prompt, recallPositivePrompt]);
const handleRecallNegativePrompt = useCallback(() => {
recallNegativePrompt(metadata?.negative_prompt);
}, [metadata?.negative_prompt, recallNegativePrompt]);
const handleRecallSeed = useCallback(() => {
recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]);
const handleRecallModel = useCallback(() => {
recallModel(metadata?.model);
}, [metadata?.model, recallModel]);
const handleRecallWidth = useCallback(() => {
recallWidth(metadata?.width);
}, [metadata?.width, recallWidth]);
const handleRecallHeight = useCallback(() => {
recallHeight(metadata?.height);
}, [metadata?.height, recallHeight]);
const handleRecallScheduler = useCallback(() => {
recallScheduler(metadata?.scheduler);
}, [metadata?.scheduler, recallScheduler]);
const handleRecallSteps = useCallback(() => {
recallSteps(metadata?.steps);
}, [metadata?.steps, recallSteps]);
const handleRecallCfgScale = useCallback(() => {
recallCfgScale(metadata?.cfg_scale);
}, [metadata?.cfg_scale, recallCfgScale]);
const handleRecallStrength = useCallback(() => {
recallStrength(metadata?.strength);
}, [metadata?.strength, recallStrength]);
if (!metadata || Object.keys(metadata).length === 0) {
return null;
}
return (
<>
{metadata.generation_mode && (
<MetadataItem
label="Generation Mode"
value={metadata.generation_mode}
/>
)}
{metadata.positive_prompt && (
<MetadataItem
label="Positive Prompt"
labelPosition="top"
value={metadata.positive_prompt}
onClick={handleRecallPositivePrompt}
/>
)}
{metadata.negative_prompt && (
<MetadataItem
label="Negative Prompt"
labelPosition="top"
value={metadata.negative_prompt}
onClick={handleRecallNegativePrompt}
/>
)}
{metadata.seed !== undefined && (
<MetadataItem
label="Seed"
value={metadata.seed}
onClick={handleRecallSeed}
/>
)}
{metadata.model !== undefined && (
<MetadataItem
label="Model"
value={metadata.model.model_name}
onClick={handleRecallModel}
/>
)}
{metadata.width && (
<MetadataItem
label="Width"
value={metadata.width}
onClick={handleRecallWidth}
/>
)}
{metadata.height && (
<MetadataItem
label="Height"
value={metadata.height}
onClick={handleRecallHeight}
/>
)}
{/* {metadata.threshold !== undefined && (
<MetadataItem
label="Noise Threshold"
value={metadata.threshold}
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
/>
)}
{metadata.perlin !== undefined && (
<MetadataItem
label="Perlin Noise"
value={metadata.perlin}
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
/>
)} */}
{metadata.scheduler && (
<MetadataItem
label="Scheduler"
value={metadata.scheduler}
onClick={handleRecallScheduler}
/>
)}
{metadata.steps && (
<MetadataItem
label="Steps"
value={metadata.steps}
onClick={handleRecallSteps}
/>
)}
{metadata.cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={metadata.cfg_scale}
onClick={handleRecallCfgScale}
/>
)}
{/* {metadata.variations && metadata.variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(metadata.variations)}
onClick={() =>
dispatch(
setSeedWeights(seedWeightsToString(metadata.variations))
)
}
/>
)}
{metadata.seamless && (
<MetadataItem
label="Seamless"
value={metadata.seamless}
onClick={() => dispatch(setSeamless(metadata.seamless))}
/>
)}
{metadata.hires_fix && (
<MetadataItem
label="High Resolution Optimization"
value={metadata.hires_fix}
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
/>
)} */}
{/* {init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)} */}
{metadata.strength && (
<MetadataItem
label="Image to image strength"
value={metadata.strength}
onClick={handleRecallStrength}
/>
)}
{/* {metadata.fit && (
<MetadataItem
label="Image to image fit"
value={metadata.fit}
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))}
/>
)} */}
</>
);
};
export default ImageMetadataActions;

View File

@ -1,131 +1,63 @@
import { ExternalLinkIcon } from '@chakra-ui/icons'; import { ExternalLinkIcon } from '@chakra-ui/icons';
import { import {
Box,
Center,
Flex, Flex,
IconButton,
Link, Link,
Tab,
TabList,
TabPanel,
TabPanels,
Tabs,
Text, Text,
Tooltip,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { memo, useMemo } from 'react';
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice'; import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import ImageMetadataActions from './ImageMetadataActions';
type MetadataItemProps = { import MetadataJSONViewer from './MetadataJSONViewer';
isLink?: boolean;
label: string;
onClick?: () => void;
value: number | string | boolean;
labelPosition?: string;
withCopy?: boolean;
};
/**
* Component to display an individual metadata item or parameter.
*/
const MetadataItem = ({
label,
value,
onClick,
isLink,
labelPosition,
withCopy = false,
}: MetadataItemProps) => {
const { t } = useTranslation();
if (!value) {
return null;
}
return (
<Flex gap={2}>
{onClick && (
<Tooltip label={`Recall ${label}`}>
<IconButton
aria-label={t('accessibility.useThisParameter')}
icon={<IoArrowUndoCircleOutline />}
size="xs"
variant="ghost"
fontSize={20}
onClick={onClick}
/>
</Tooltip>
)}
{withCopy && (
<Tooltip label={`Copy ${label}`}>
<IconButton
aria-label={`Copy ${label}`}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(value.toString())}
/>
</Tooltip>
)}
<Flex direction={labelPosition ? 'column' : 'row'}>
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
{label}:
</Text>
{isLink ? (
<Link href={value.toString()} isExternal wordBreak="break-all">
{value.toString()} <ExternalLinkIcon mx="2px" />
</Link>
) : (
<Text overflowY="scroll" wordBreak="break-all">
{value.toString()}
</Text>
)}
</Flex>
</Flex>
);
};
type ImageMetadataViewerProps = { type ImageMetadataViewerProps = {
image: ImageDTO; image: ImageDTO;
}; };
/**
* Image metadata viewer overlays currently selected image and provides
* access to any of its metadata for use in processing.
*/
const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => { const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch(); // TODO: fix hotkeys
const { // const dispatch = useAppDispatch();
recallBothPrompts, // useHotkeys('esc', () => {
recallPositivePrompt, // dispatch(setShouldShowImageDetails(false));
recallNegativePrompt, // });
recallSeed,
recallInitialImage,
recallCfgScale,
recallModel,
recallScheduler,
recallSteps,
recallWidth,
recallHeight,
recallStrength,
recallAllParameters,
} = useRecallParameters();
useHotkeys('esc', () => { const { data } = useGetImageMetadataQuery(image?.image_name ?? skipToken);
dispatch(setShouldShowImageDetails(false)); const metadata = data?.metadata;
const tabData = useMemo(() => {
const _tabData: { label: string; data: object; copyTooltip: string }[] = [];
if (data?.metadata) {
_tabData.push({
label: 'Core Metadata',
data: data?.metadata,
copyTooltip: 'Copy Core Metadata JSON',
}); });
}
const sessionId = image?.session_id; if (image) {
_tabData.push({
label: 'Image Details',
data: image,
copyTooltip: 'Copy Image Details JSON',
});
}
const metadata = image?.metadata; if (data?.graph) {
_tabData.push({
const { t } = useTranslation(); label: 'Graph',
data: data?.graph,
const metadataJSON = JSON.stringify(image, null, 2); copyTooltip: 'Copy Graph JSON',
});
}
return _tabData;
}, [data?.metadata, data?.graph, image]);
return ( return (
<Flex <Flex
@ -136,11 +68,13 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
width: 'full', width: 'full',
height: 'full', height: 'full',
backdropFilter: 'blur(20px)', backdropFilter: 'blur(20px)',
bg: 'whiteAlpha.600', bg: 'baseAlpha.200',
_dark: { _dark: {
bg: 'blackAlpha.600', bg: 'blackAlpha.600',
}, },
overflow: 'scroll', borderRadius: 'base',
position: 'absolute',
overflow: 'hidden',
}} }}
> >
<Flex gap={2}> <Flex gap={2}>
@ -150,179 +84,42 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
<ExternalLinkIcon mx="2px" /> <ExternalLinkIcon mx="2px" />
</Link> </Link>
</Flex> </Flex>
{metadata && Object.keys(metadata).length > 0 ? (
<>
{metadata.type && (
<MetadataItem label="Invocation type" value={metadata.type} />
)}
{sessionId && <MetadataItem label="Session ID" value={sessionId} />}
{metadata.positive_conditioning && (
<MetadataItem
label="Positive Prompt"
labelPosition="top"
value={metadata.positive_conditioning}
onClick={() =>
recallPositivePrompt(metadata.positive_conditioning)
}
/>
)}
{metadata.negative_conditioning && (
<MetadataItem
label="Negative Prompt"
labelPosition="top"
value={metadata.negative_conditioning}
onClick={() =>
recallNegativePrompt(metadata.negative_conditioning)
}
/>
)}
{metadata.seed !== undefined && (
<MetadataItem
label="Seed"
value={metadata.seed}
onClick={() => recallSeed(metadata.seed)}
/>
)}
{metadata.model !== undefined && (
<MetadataItem
label="Model"
value={metadata.model}
onClick={() => recallModel(metadata.model)}
/>
)}
{metadata.width && (
<MetadataItem
label="Width"
value={metadata.width}
onClick={() => recallWidth(metadata.width)}
/>
)}
{metadata.height && (
<MetadataItem
label="Height"
value={metadata.height}
onClick={() => recallHeight(metadata.height)}
/>
)}
{/* {metadata.threshold !== undefined && (
<MetadataItem
label="Noise Threshold"
value={metadata.threshold}
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
/>
)}
{metadata.perlin !== undefined && (
<MetadataItem
label="Perlin Noise"
value={metadata.perlin}
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
/>
)} */}
{metadata.scheduler && (
<MetadataItem
label="Scheduler"
value={metadata.scheduler}
onClick={() => recallScheduler(metadata.scheduler)}
/>
)}
{metadata.steps && (
<MetadataItem
label="Steps"
value={metadata.steps}
onClick={() => recallSteps(metadata.steps)}
/>
)}
{metadata.cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={metadata.cfg_scale}
onClick={() => recallCfgScale(metadata.cfg_scale)}
/>
)}
{/* {metadata.variations && metadata.variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(metadata.variations)}
onClick={() =>
dispatch(
setSeedWeights(seedWeightsToString(metadata.variations))
)
}
/>
)}
{metadata.seamless && (
<MetadataItem
label="Seamless"
value={metadata.seamless}
onClick={() => dispatch(setSeamless(metadata.seamless))}
/>
)}
{metadata.hires_fix && (
<MetadataItem
label="High Resolution Optimization"
value={metadata.hires_fix}
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
/>
)} */}
{/* {init_image_path && ( <ImageMetadataActions metadata={metadata} />
<MetadataItem
label="Initial image" <Tabs
value={init_image_path} variant="line"
isLink sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
onClick={() => dispatch(setInitialImage(init_image_path))} >
/> <TabList>
)} */} {tabData.map((tab) => (
{metadata.strength && ( <Tab
<MetadataItem key={tab.label}
label="Image to image strength"
value={metadata.strength}
onClick={() => recallStrength(metadata.strength)}
/>
)}
{/* {metadata.fit && (
<MetadataItem
label="Image to image fit"
value={metadata.fit}
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))}
/>
)} */}
</>
) : (
<Center width="100%" pt={10}>
<Text fontSize="lg" fontWeight="semibold">
No metadata available
</Text>
</Center>
)}
<Flex gap={2} direction="column" overflow="auto">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<OverlayScrollbarsComponent defer>
<Box
sx={{ sx={{
padding: 4, borderTopRadius: 'base',
borderRadius: 'base',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
w: 'full',
}} }}
> >
<pre>{metadataJSON}</pre> <Text sx={{ color: 'base.700', _dark: { color: 'base.300' } }}>
</Box> {tab.label}
</OverlayScrollbarsComponent> </Text>
</Flex> </Tab>
))}
</TabList>
<TabPanels sx={{ w: 'full', h: 'full' }}>
{tabData.map((tab) => (
<TabPanel
key={tab.label}
sx={{ w: 'full', h: 'full', p: 0, pt: 4 }}
>
<MetadataJSONViewer
jsonObject={tab.data}
copyTooltip={tab.copyTooltip}
/>
</TabPanel>
))}
</TabPanels>
</Tabs>
</Flex> </Flex>
); );
}; };

View File

@ -0,0 +1,77 @@
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { Flex, IconButton, Link, Text, Tooltip } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
type MetadataItemProps = {
isLink?: boolean;
label: string;
onClick?: () => void;
value: number | string | boolean;
labelPosition?: string;
withCopy?: boolean;
};
/**
* Component to display an individual metadata item or parameter.
*/
const MetadataItem = ({
label,
value,
onClick,
isLink,
labelPosition,
withCopy = false,
}: MetadataItemProps) => {
const { t } = useTranslation();
if (!value) {
return null;
}
return (
<Flex gap={2}>
{onClick && (
<Tooltip label={`Recall ${label}`}>
<IconButton
aria-label={t('accessibility.useThisParameter')}
icon={<IoArrowUndoCircleOutline />}
size="xs"
variant="ghost"
fontSize={20}
onClick={onClick}
/>
</Tooltip>
)}
{withCopy && (
<Tooltip label={`Copy ${label}`}>
<IconButton
aria-label={`Copy ${label}`}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(value.toString())}
/>
</Tooltip>
)}
<Flex direction={labelPosition ? 'column' : 'row'}>
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
{label}:
</Text>
{isLink ? (
<Link href={value.toString()} isExternal wordBreak="break-all">
{value.toString()} <ExternalLinkIcon mx="2px" />
</Link>
) : (
<Text overflowY="scroll" wordBreak="break-all">
{value.toString()}
</Text>
)}
</Flex>
</Flex>
);
};
export default MetadataItem;

View File

@ -0,0 +1,70 @@
import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { useMemo } from 'react';
import { FaCopy } from 'react-icons/fa';
type Props = {
copyTooltip: string;
jsonObject: object;
};
const MetadataJSONViewer = (props: Props) => {
const { copyTooltip, jsonObject } = props;
const jsonString = useMemo(
() => JSON.stringify(jsonObject, null, 2),
[jsonObject]
);
return (
<Flex
sx={{
borderRadius: 'base',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
flexGrow: 1,
w: 'full',
h: 'full',
position: 'relative',
}}
>
<Box
sx={{
position: 'absolute',
top: 0,
left: 0,
right: 0,
bottom: 0,
overflow: 'auto',
p: 4,
}}
>
<OverlayScrollbarsComponent
defer
style={{ height: '100%', width: '100%' }}
options={{
scrollbars: {
visibility: 'auto',
autoHide: 'move',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
}}
>
<pre>{jsonString}</pre>
</OverlayScrollbarsComponent>
</Box>
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
<Tooltip label={copyTooltip}>
<IconButton
aria-label={copyTooltip}
icon={<FaCopy />}
variant="ghost"
onClick={() => navigator.clipboard.writeText(jsonString)}
/>
</Tooltip>
</Flex>
</Flex>
);
};
export default MetadataJSONViewer;

View File

@ -45,7 +45,7 @@ const ParamLoraSelect = () => {
data.push({ data.push({
value: id, value: id,
label: lora.name, label: lora.model_name,
disabled, disabled,
group: MODEL_TYPE_MAP[lora.base_model], group: MODEL_TYPE_MAP[lora.base_model],
tooltip: disabled tooltip: disabled

View File

@ -1,94 +0,0 @@
import { RootState } from 'app/store/store';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
import { CollectInvocation, ControlNetInvocation } from 'services/api/types';
import { NonNullableGraph } from '../types/types';
import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
export const addControlNetToLinearGraph = (
graph: NonNullableGraph,
baseNodeId: string,
state: RootState
): void => {
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
const validControlNets = getValidControlNets(controlNets);
if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (validControlNets.length > 1) {
// We have multiple controlnets, add ControlNet collector
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
};
graph.nodes[controlNetIterateNode.id] = controlNetIterateNode;
graph.edges.push({
source: { node_id: controlNetIterateNode.id, field: 'collection' },
destination: {
node_id: baseNodeId,
field: 'control',
},
});
}
validControlNets.forEach((controlNet) => {
const {
controlNetId,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
controlMode,
model,
processorType,
weight,
} = controlNet;
const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`,
type: 'controlnet',
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_mode: controlMode,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
controlNetNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
controlNetNode.image = {
image_name: controlImage,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[controlNetNode.id] = controlNetNode;
if (validControlNets.length > 1) {
// if we have multiple controlnets, link to the collector
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CONTROL_NET_COLLECT,
field: 'item',
},
});
} else {
// otherwise, link directly to the base node
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: baseNodeId,
field: 'control',
},
});
}
});
}
};

View File

@ -1,40 +0,0 @@
import {
Edge,
ImageToImageInvocation,
InpaintInvocation,
IterateInvocation,
RandomRangeInvocation,
RangeInvocation,
TextToImageInvocation,
} from 'services/api/types';
export const buildEdges = (
baseNode: TextToImageInvocation | ImageToImageInvocation | InpaintInvocation,
rangeNode: RangeInvocation | RandomRangeInvocation,
iterateNode: IterateInvocation
): Edge[] => {
const edges: Edge[] = [
{
source: {
node_id: rangeNode.id,
field: 'collection',
},
destination: {
node_id: iterateNode.id,
field: 'collection',
},
},
{
source: {
node_id: iterateNode.id,
field: 'item',
},
destination: {
node_id: baseNode.id,
field: 'seed',
},
},
];
return edges;
};

View File

@ -0,0 +1,100 @@
import { RootState } from 'app/store/store';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
import { omit } from 'lodash-es';
import {
CollectInvocation,
ControlField,
ControlNetInvocation,
MetadataAccumulatorInvocation,
} from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import { CONTROL_NET_COLLECT, METADATA_ACCUMULATOR } from './constants';
export const addControlNetToLinearGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): void => {
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
const validControlNets = getValidControlNets(controlNets);
const metadataAccumulator = graph.nodes[
METADATA_ACCUMULATOR
] as MetadataAccumulatorInvocation;
if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (validControlNets.length) {
// We have multiple controlnets, add ControlNet collector
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
};
graph.nodes[CONTROL_NET_COLLECT] = controlNetIterateNode;
graph.edges.push({
source: { node_id: CONTROL_NET_COLLECT, field: 'collection' },
destination: {
node_id: baseNodeId,
field: 'control',
},
});
validControlNets.forEach((controlNet) => {
const {
controlNetId,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
controlMode,
model,
processorType,
weight,
} = controlNet;
const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`,
type: 'controlnet',
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_mode: controlMode,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
controlNetNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
controlNetNode.image = {
image_name: controlImage,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[controlNetNode.id] = controlNetNode;
// metadata accumulator only needs a control field - not the whole node
// extract what we need and add to the accumulator
const controlField = omit(controlNetNode, [
'id',
'type',
]) as ControlField;
metadataAccumulator.controlnets.push(controlField);
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CONTROL_NET_COLLECT,
field: 'item',
},
});
});
}
}
};

View File

@ -1,8 +1,10 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { unset } from 'lodash-es';
import { import {
DynamicPromptInvocation, DynamicPromptInvocation,
IterateInvocation, IterateInvocation,
MetadataAccumulatorInvocation,
NoiseInvocation, NoiseInvocation,
RandomIntInvocation, RandomIntInvocation,
RangeOfSizeInvocation, RangeOfSizeInvocation,
@ -10,16 +12,16 @@ import {
import { import {
DYNAMIC_PROMPT, DYNAMIC_PROMPT,
ITERATE, ITERATE,
METADATA_ACCUMULATOR,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
RANGE_OF_SIZE, RANGE_OF_SIZE,
} from './constants'; } from './constants';
import { unset } from 'lodash-es';
export const addDynamicPromptsToGraph = ( export const addDynamicPromptsToGraph = (
graph: NonNullableGraph, state: RootState,
state: RootState graph: NonNullableGraph
): void => { ): void => {
const { positivePrompt, iterations, seed, shouldRandomizeSeed } = const { positivePrompt, iterations, seed, shouldRandomizeSeed } =
state.generation; state.generation;
@ -30,6 +32,10 @@ export const addDynamicPromptsToGraph = (
maxPrompts, maxPrompts,
} = state.dynamicPrompts; } = state.dynamicPrompts;
const metadataAccumulator = graph.nodes[
METADATA_ACCUMULATOR
] as MetadataAccumulatorInvocation;
if (isDynamicPromptsEnabled) { if (isDynamicPromptsEnabled) {
// iteration is handled via dynamic prompts // iteration is handled via dynamic prompts
unset(graph.nodes[POSITIVE_CONDITIONING], 'prompt'); unset(graph.nodes[POSITIVE_CONDITIONING], 'prompt');
@ -74,6 +80,18 @@ export const addDynamicPromptsToGraph = (
} }
); );
// hook up positive prompt to metadata
graph.edges.push({
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: METADATA_ACCUMULATOR,
field: 'positive_prompt',
},
});
if (shouldRandomizeSeed) { if (shouldRandomizeSeed) {
// Random int node to generate the starting seed // Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = { const randomIntNode: RandomIntInvocation = {
@ -88,11 +106,22 @@ export const addDynamicPromptsToGraph = (
source: { node_id: RANDOM_INT, field: 'a' }, source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: NOISE, field: 'seed' }, destination: { node_id: NOISE, field: 'seed' },
}); });
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: METADATA_ACCUMULATOR, field: 'seed' },
});
} else { } else {
// User specified seed, so set the start of the range of size to the seed // User specified seed, so set the start of the range of size to the seed
(graph.nodes[NOISE] as NoiseInvocation).seed = seed; (graph.nodes[NOISE] as NoiseInvocation).seed = seed;
// hook up seed to metadata
metadataAccumulator.seed = seed;
} }
} else { } else {
// no dynamic prompt - hook up positive prompt
metadataAccumulator.positive_prompt = positivePrompt;
const rangeOfSizeNode: RangeOfSizeInvocation = { const rangeOfSizeNode: RangeOfSizeInvocation = {
id: RANGE_OF_SIZE, id: RANGE_OF_SIZE,
type: 'range_of_size', type: 'range_of_size',
@ -130,6 +159,18 @@ export const addDynamicPromptsToGraph = (
}, },
}); });
// hook up seed to metadata
graph.edges.push({
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: METADATA_ACCUMULATOR,
field: 'seed',
},
});
// handle seed // handle seed
if (shouldRandomizeSeed) { if (shouldRandomizeSeed) {
// Random int node to generate the starting seed // Random int node to generate the starting seed

View File

@ -1,19 +1,23 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, size } from 'lodash-es'; import { forEach, size } from 'lodash-es';
import { LoraLoaderInvocation } from 'services/api/types'; import {
LoraLoaderInvocation,
MetadataAccumulatorInvocation,
} from 'services/api/types';
import { modelIdToLoRAModelField } from '../modelIdToLoRAName'; import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
import { import {
CLIP_SKIP, CLIP_SKIP,
LORA_LOADER, LORA_LOADER,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
} from './constants'; } from './constants';
export const addLoRAsToGraph = ( export const addLoRAsToGraph = (
graph: NonNullableGraph,
state: RootState, state: RootState,
graph: NonNullableGraph,
baseNodeId: string baseNodeId: string
): void => { ): void => {
/** /**
@ -26,6 +30,9 @@ export const addLoRAsToGraph = (
const { loras } = state.lora; const { loras } = state.lora;
const loraCount = size(loras); const loraCount = size(loras);
const metadataAccumulator = graph.nodes[
METADATA_ACCUMULATOR
] as MetadataAccumulatorInvocation;
if (loraCount > 0) { if (loraCount > 0) {
// Remove MAIN_MODEL_LOADER unet connection to feed it to LoRAs // Remove MAIN_MODEL_LOADER unet connection to feed it to LoRAs
@ -62,6 +69,10 @@ export const addLoRAsToGraph = (
weight, weight,
}; };
// add the lora to the metadata accumulator
metadataAccumulator.loras.push({ lora: loraField, weight });
// add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode; graph.nodes[currentLoraNodeId] = loraLoaderNode;
if (currentLoraIndex === 0) { if (currentLoraIndex === 0) {

View File

@ -1,5 +1,6 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { MetadataAccumulatorInvocation } from 'services/api/types';
import { modelIdToVAEModelField } from '../modelIdToVAEModelField'; import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
import { import {
IMAGE_TO_IMAGE_GRAPH, IMAGE_TO_IMAGE_GRAPH,
@ -8,18 +9,22 @@ import {
INPAINT_GRAPH, INPAINT_GRAPH,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
METADATA_ACCUMULATOR,
TEXT_TO_IMAGE_GRAPH, TEXT_TO_IMAGE_GRAPH,
VAE_LOADER, VAE_LOADER,
} from './constants'; } from './constants';
export const addVAEToGraph = ( export const addVAEToGraph = (
graph: NonNullableGraph, state: RootState,
state: RootState graph: NonNullableGraph
): void => { ): void => {
const { vae } = state.generation; const { vae } = state.generation;
const vae_model = modelIdToVAEModelField(vae?.id || ''); const vae_model = modelIdToVAEModelField(vae?.id || '');
const isAutoVae = !vae; const isAutoVae = !vae;
const metadataAccumulator = graph.nodes[
METADATA_ACCUMULATOR
] as MetadataAccumulatorInvocation;
if (!isAutoVae) { if (!isAutoVae) {
graph.nodes[VAE_LOADER] = { graph.nodes[VAE_LOADER] = {
@ -67,4 +72,8 @@ export const addVAEToGraph = (
}, },
}); });
} }
if (vae) {
metadataAccumulator.vae = vae_model;
}
}; };

View File

@ -7,8 +7,7 @@ import {
ImageResizeInvocation, ImageResizeInvocation,
ImageToLatentsInvocation, ImageToLatentsInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
@ -19,6 +18,7 @@ import {
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
LATENTS_TO_LATENTS, LATENTS_TO_LATENTS,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
@ -37,7 +37,7 @@ export const buildCanvasImageToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: currentModel, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -50,7 +50,10 @@ export const buildCanvasImageToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToMainModelField(currentModel?.id || ''); if (!model) {
moduleLog.error('No model found in state');
throw new Error('No model found in state');
}
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
@ -275,16 +278,51 @@ export const buildCanvasImageToImageGraph = (
}); });
} }
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS); // add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'img2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
clip_skip: clipSkip,
strength,
init_image: initialImage.image_name,
};
// Add VAE graph.edges.push({
addVAEToGraph(graph, state); source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// add dynamic prompts, mutating `graph` // add LoRA support
addDynamicPromptsToGraph(graph, state); addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS);
// optionally add custom VAE
addVAEToGraph(state, graph);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
// add controlnet, mutating `graph` // add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state); addControlNetToLinearGraph(state, graph, LATENTS_TO_LATENTS);
return graph; return graph;
}; };

View File

@ -212,10 +212,10 @@ export const buildCanvasInpaintGraph = (
], ],
}; };
addLoRAsToGraph(graph, state, INPAINT); addLoRAsToGraph(state, graph, INPAINT);
// Add VAE // Add VAE
addVAEToGraph(graph, state); addVAEToGraph(state, graph);
// handle seed // handle seed
if (shouldRandomizeSeed) { if (shouldRandomizeSeed) {

View File

@ -1,8 +1,8 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
@ -10,6 +10,7 @@ import {
CLIP_SKIP, CLIP_SKIP,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
@ -17,6 +18,8 @@ import {
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
const moduleLog = log.child({ namespace: 'nodes' });
/** /**
* Builds the Canvas tab's Text to Image graph. * Builds the Canvas tab's Text to Image graph.
*/ */
@ -26,7 +29,7 @@ export const buildCanvasTextToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: currentModel, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -38,7 +41,10 @@ export const buildCanvasTextToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToMainModelField(currentModel?.id || ''); if (!model) {
moduleLog.error('No model found in state');
throw new Error('No model found in state');
}
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
@ -180,16 +186,49 @@ export const buildCanvasTextToImageGraph = (
], ],
}; };
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS); // add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'txt2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
clip_skip: clipSkip,
};
// Add VAE graph.edges.push({
addVAEToGraph(graph, state); source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// add dynamic prompts, mutating `graph` // add LoRA support
addDynamicPromptsToGraph(graph, state); addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
// optionally add custom VAE
addVAEToGraph(state, graph);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
// add controlnet, mutating `graph` // add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state); addControlNetToLinearGraph(state, graph, TEXT_TO_LATENTS);
return graph; return graph;
}; };

View File

@ -3,25 +3,21 @@ import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { import {
ImageCollectionInvocation,
ImageResizeInvocation, ImageResizeInvocation,
ImageToLatentsInvocation, ImageToLatentsInvocation,
IterateInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { import {
CLIP_SKIP, CLIP_SKIP,
IMAGE_COLLECTION,
IMAGE_COLLECTION_ITERATE,
IMAGE_TO_IMAGE_GRAPH, IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS, IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
LATENTS_TO_LATENTS, LATENTS_TO_LATENTS,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
@ -39,7 +35,7 @@ export const buildLinearImageToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: currentModel, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -53,14 +49,15 @@ export const buildLinearImageToImageGraph = (
shouldUseNoiseSettings, shouldUseNoiseSettings,
} = state.generation; } = state.generation;
const { // TODO: add batch functionality
isEnabled: isBatchEnabled, // const {
imageNames: batchImageNames, // isEnabled: isBatchEnabled,
asInitialImage, // imageNames: batchImageNames,
} = state.batch; // asInitialImage,
// } = state.batch;
const shouldBatch = // const shouldBatch =
isBatchEnabled && batchImageNames.length > 0 && asInitialImage; // isBatchEnabled && batchImageNames.length > 0 && asInitialImage;
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
@ -71,12 +68,15 @@ export const buildLinearImageToImageGraph = (
* the `fit` param. These are added to the graph at the end. * the `fit` param. These are added to the graph at the end.
*/ */
if (!initialImage && !shouldBatch) { if (!initialImage) {
moduleLog.error('No initial image found in state'); moduleLog.error('No initial image found in state');
throw new Error('No initial image found in state'); throw new Error('No initial image found in state');
} }
const model = modelIdToMainModelField(currentModel?.id || ''); if (!model) {
moduleLog.error('No model found in state');
throw new Error('No model found in state');
}
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
@ -295,51 +295,87 @@ export const buildLinearImageToImageGraph = (
}); });
} }
if (isBatchEnabled && asInitialImage && batchImageNames.length > 0) { // TODO: add batch functionality
// we are going to connect an iterate up to the init image // if (isBatchEnabled && asInitialImage && batchImageNames.length > 0) {
delete (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image; // // we are going to connect an iterate up to the init image
// delete (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image;
const imageCollection: ImageCollectionInvocation = { // const imageCollection: ImageCollectionInvocation = {
id: IMAGE_COLLECTION, // id: IMAGE_COLLECTION,
type: 'image_collection', // type: 'image_collection',
images: batchImageNames.map((image_name) => ({ image_name })), // images: batchImageNames.map((image_name) => ({ image_name })),
// };
// const imageCollectionIterate: IterateInvocation = {
// id: IMAGE_COLLECTION_ITERATE,
// type: 'iterate',
// };
// graph.nodes[IMAGE_COLLECTION] = imageCollection;
// graph.nodes[IMAGE_COLLECTION_ITERATE] = imageCollectionIterate;
// graph.edges.push({
// source: { node_id: IMAGE_COLLECTION, field: 'collection' },
// destination: {
// node_id: IMAGE_COLLECTION_ITERATE,
// field: 'collection',
// },
// });
// graph.edges.push({
// source: { node_id: IMAGE_COLLECTION_ITERATE, field: 'item' },
// destination: {
// node_id: IMAGE_TO_LATENTS,
// field: 'image',
// },
// });
// }
// add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'img2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
clip_skip: clipSkip,
strength,
init_image: initialImage.imageName,
}; };
const imageCollectionIterate: IterateInvocation = {
id: IMAGE_COLLECTION_ITERATE,
type: 'iterate',
};
graph.nodes[IMAGE_COLLECTION] = imageCollection;
graph.nodes[IMAGE_COLLECTION_ITERATE] = imageCollectionIterate;
graph.edges.push({ graph.edges.push({
source: { node_id: IMAGE_COLLECTION, field: 'collection' }, source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: { destination: {
node_id: IMAGE_COLLECTION_ITERATE, node_id: LATENTS_TO_IMAGE,
field: 'collection', field: 'metadata',
}, },
}); });
graph.edges.push({ // add LoRA support
source: { node_id: IMAGE_COLLECTION_ITERATE, field: 'item' }, addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS);
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'image',
},
});
}
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS); // optionally add custom VAE
addVAEToGraph(state, graph);
// Add VAE // add dynamic prompts - also sets up core iteration and seed
addVAEToGraph(graph, state); addDynamicPromptsToGraph(state, graph);
// add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);
// add controlnet, mutating `graph` // add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state); addControlNetToLinearGraph(state, graph, LATENTS_TO_LATENTS);
return graph; return graph;
}; };

View File

@ -1,8 +1,8 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
@ -10,6 +10,7 @@ import {
CLIP_SKIP, CLIP_SKIP,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
@ -17,13 +18,15 @@ import {
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
const moduleLog = log.child({ namespace: 'nodes' });
export const buildLinearTextToImageGraph = ( export const buildLinearTextToImageGraph = (
state: RootState state: RootState
): NonNullableGraph => { ): NonNullableGraph => {
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: currentModel, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -34,12 +37,15 @@ export const buildLinearTextToImageGraph = (
shouldUseNoiseSettings, shouldUseNoiseSettings,
} = state.generation; } = state.generation;
const model = modelIdToMainModelField(currentModel?.id || '');
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise; : initialGenerationState.shouldUseCpuNoise;
if (!model) {
moduleLog.error('No model found in state');
throw new Error('No model found in state');
}
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node * full graph here as a template. Then use the parameters from app state and set friendlier node
@ -176,16 +182,49 @@ export const buildLinearTextToImageGraph = (
], ],
}; };
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS); // add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'txt2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
clip_skip: clipSkip,
};
// Add Custom VAE Support graph.edges.push({
addVAEToGraph(graph, state); source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// add dynamic prompts, mutating `graph` // add LoRA support
addDynamicPromptsToGraph(graph, state); addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
// optionally add custom VAE
addVAEToGraph(state, graph);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
// add controlnet, mutating `graph` // add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state); addControlNetToLinearGraph(state, graph, TEXT_TO_LATENTS);
return graph; return graph;
}; };

View File

@ -19,6 +19,7 @@ export const CONTROL_NET_COLLECT = 'control_net_collect';
export const DYNAMIC_PROMPT = 'dynamic_prompt'; export const DYNAMIC_PROMPT = 'dynamic_prompt';
export const IMAGE_COLLECTION = 'image_collection'; export const IMAGE_COLLECTION = 'image_collection';
export const IMAGE_COLLECTION_ITERATE = 'image_collection_iterate'; export const IMAGE_COLLECTION_ITERATE = 'image_collection_iterate';
export const METADATA_ACCUMULATOR = 'metadata_accumulator';
// friendly graph ids // friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph'; export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';

View File

@ -5,17 +5,21 @@ import {
InputFieldTemplate, InputFieldTemplate,
InvocationSchemaObject, InvocationSchemaObject,
InvocationTemplate, InvocationTemplate,
isInvocationSchemaObject,
OutputFieldTemplate, OutputFieldTemplate,
isInvocationSchemaObject,
} from '../types/types'; } from '../types/types';
import { import {
buildInputFieldTemplate, buildInputFieldTemplate,
buildOutputFieldTemplates, buildOutputFieldTemplates,
} from './fieldTemplateBuilders'; } from './fieldTemplateBuilders';
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate']; const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate', 'core_metadata'];
const invocationDenylist = ['Graph', 'InvocationMeta']; const invocationDenylist = [
'Graph',
'InvocationMeta',
'MetadataAccumulatorInvocation',
];
export const parseSchema = (openAPI: OpenAPIV3.Document) => { export const parseSchema = (openAPI: OpenAPIV3.Document) => {
// filter out non-invocation schemas, plus some tricky invocations for now // filter out non-invocation schemas, plus some tricky invocations for now

View File

@ -162,7 +162,7 @@ export const useRecallParameters = () => {
parameterNotSetToast(); parameterNotSetToast();
return; return;
} }
dispatch(modelSelected(model?.id || '')); dispatch(modelSelected(model));
parameterSetToast(); parameterSetToast();
}, },
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]

View File

@ -1,8 +1,10 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { ImageDTO } from 'services/api/types'; import { ImageDTO, MainModelField } from 'services/api/types';
export const initialImageSelected = createAction<ImageDTO | string | undefined>( export const initialImageSelected = createAction<ImageDTO | string | undefined>(
'generation/initialImageSelected' 'generation/initialImageSelected'
); );
export const modelSelected = createAction<string>('generation/modelSelected'); export const modelSelected = createAction<MainModelField>(
'generation/modelSelected'
);

View File

@ -8,12 +8,11 @@ import {
setShouldShowAdvancedOptions, setShouldShowAdvancedOptions,
} from 'features/ui/store/uiSlice'; } from 'features/ui/store/uiSlice';
import { clamp } from 'lodash-es'; import { clamp } from 'lodash-es';
import { ImageDTO } from 'services/api/types'; import { ImageDTO, MainModelField } from 'services/api/types';
import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip'; import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
import { import {
CfgScaleParam, CfgScaleParam,
HeightParam, HeightParam,
MainModelParam,
NegativePromptParam, NegativePromptParam,
PositivePromptParam, PositivePromptParam,
SchedulerParam, SchedulerParam,
@ -54,7 +53,7 @@ export interface GenerationState {
shouldUseSymmetry: boolean; shouldUseSymmetry: boolean;
horizontalSymmetrySteps: number; horizontalSymmetrySteps: number;
verticalSymmetrySteps: number; verticalSymmetrySteps: number;
model: MainModelParam | null; model: MainModelField | null;
vae: VaeModelParam | null; vae: VaeModelParam | null;
seamlessXAxis: boolean; seamlessXAxis: boolean;
seamlessYAxis: boolean; seamlessYAxis: boolean;
@ -227,23 +226,17 @@ export const generationSlice = createSlice({
const { image_name, width, height } = action.payload; const { image_name, width, height } = action.payload;
state.initialImage = { imageName: image_name, width, height }; state.initialImage = { imageName: image_name, width, height };
}, },
modelSelected: (state, action: PayloadAction<string>) => { modelChanged: (state, action: PayloadAction<MainModelField | null>) => {
const [base_model, type, name] = action.payload.split('/'); if (!action.payload) {
state.model = null;
}
state.model = zMainModel.parse({ state.model = zMainModel.parse(action.payload);
id: action.payload,
base_model,
name,
type,
});
// Clamp ClipSkip Based On Selected Model // Clamp ClipSkip Based On Selected Model
const { maxClip } = clipSkipMap[state.model.base_model]; const { maxClip } = clipSkipMap[state.model.base_model];
state.clipSkip = clamp(state.clipSkip, 0, maxClip); state.clipSkip = clamp(state.clipSkip, 0, maxClip);
}, },
modelChanged: (state, action: PayloadAction<MainModelParam>) => {
state.model = action.payload;
},
vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => { vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
state.vae = action.payload; state.vae = action.payload;
}, },

View File

@ -135,8 +135,7 @@ export type BaseModelParam = z.infer<typeof zBaseModel>;
* TODO: Make this a dynamically generated enum? * TODO: Make this a dynamically generated enum?
*/ */
export const zMainModel = z.object({ export const zMainModel = z.object({
id: z.string(), model_name: z.string(),
name: z.string(),
base_model: zBaseModel, base_model: zBaseModel,
}); });

View File

@ -1,13 +1,16 @@
import { memo, useCallback, useEffect, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { modelIdToMainModelField } from 'features/nodes/util/modelIdToMainModelField';
import { modelSelected } from 'features/parameters/store/actions'; import { modelSelected } from 'features/parameters/store/actions';
import { forEach, isString } from 'lodash-es'; import { forEach } from 'lodash-es';
import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models';
export const MODEL_TYPE_MAP = { export const MODEL_TYPE_MAP = {
@ -15,13 +18,17 @@ export const MODEL_TYPE_MAP = {
'sd-2': 'Stable Diffusion 2.x', 'sd-2': 'Stable Diffusion 2.x',
}; };
const selector = createSelector(
stateSelector,
(state) => ({ currentModel: state.generation.model }),
defaultSelectorOptions
);
const ModelSelect = () => { const ModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const currentModel = useAppSelector( const { currentModel } = useAppSelector(selector);
(state: RootState) => state.generation.model
);
const { data: mainModels, isLoading } = useGetMainModelsQuery(); const { data: mainModels, isLoading } = useGetMainModelsQuery();
@ -39,7 +46,7 @@ const ModelSelect = () => {
data.push({ data.push({
value: id, value: id,
label: model.name, label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model], group: MODEL_TYPE_MAP[model.base_model],
}); });
}); });
@ -48,7 +55,10 @@ const ModelSelect = () => {
}, [mainModels]); }, [mainModels]);
const selectedModel = useMemo( const selectedModel = useMemo(
() => mainModels?.entities[currentModel?.id || ''], () =>
mainModels?.entities[
`${currentModel?.base_model}/main/${currentModel?.model_name}`
],
[mainModels?.entities, currentModel] [mainModels?.entities, currentModel]
); );
@ -57,31 +67,13 @@ const ModelSelect = () => {
if (!v) { if (!v) {
return; return;
} }
dispatch(modelSelected(v));
const modelField = modelIdToMainModelField(v);
dispatch(modelSelected(modelField));
}, },
[dispatch] [dispatch]
); );
useEffect(() => {
if (isLoading) {
// return early here to avoid resetting model selection before we've loaded the available models
return;
}
if (selectedModel && mainModels?.ids.includes(selectedModel?.id)) {
// the selected model is an available model, no need to change it
return;
}
const firstModel = mainModels?.ids[0];
if (!isString(firstModel)) {
return;
}
handleChangeModel(firstModel);
}, [handleChangeModel, isLoading, mainModels?.ids, selectedModel]);
return isLoading ? ( return isLoading ? (
<IAIMantineSelect <IAIMantineSelect
label={t('modelManager.model')} label={t('modelManager.model')}
@ -94,9 +86,10 @@ const ModelSelect = () => {
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={t('modelManager.model')} label={t('modelManager.model')}
value={selectedModel?.id} value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models detected!'} placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data} data={data}
error={data.length === 0} error={data.length === 0}
disabled={data.length === 0}
onChange={handleChangeModel} onChange={handleChangeModel}
/> />
); );

View File

@ -50,7 +50,7 @@ const VAESelect = () => {
data.push({ data.push({
value: id, value: id,
label: model.name, label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model], group: MODEL_TYPE_MAP[model.base_model],
disabled, disabled,
tooltip: disabled tooltip: disabled

View File

@ -1,13 +1,22 @@
import { ApiFullTagDescription, api } from '..'; import { ApiFullTagDescription, api } from '..';
import { components } from '../schema';
import { ImageDTO } from '../types'; import { ImageDTO } from '../types';
/**
* This is an unsafe type; the object inside is not guaranteed to be valid.
*/
export type UnsafeImageMetadata = {
metadata: components['schemas']['CoreMetadata'];
graph: NonNullable<components['schemas']['Graph']>;
};
export const imagesApi = api.injectEndpoints({ export const imagesApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
/** /**
* Image Queries * Image Queries
*/ */
getImageDTO: build.query<ImageDTO, string>({ getImageDTO: build.query<ImageDTO, string>({
query: (image_name) => ({ url: `images/${image_name}/metadata` }), query: (image_name) => ({ url: `images/${image_name}` }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [{ type: 'Image', id: arg }]; const tags: ApiFullTagDescription[] = [{ type: 'Image', id: arg }];
if (result?.board_id) { if (result?.board_id) {
@ -17,7 +26,17 @@ export const imagesApi = api.injectEndpoints({
}, },
keepUnusedDataFor: 86400, // 24 hours keepUnusedDataFor: 86400, // 24 hours
}), }),
getImageMetadata: build.query<UnsafeImageMetadata, string>({
query: (image_name) => ({ url: `images/${image_name}/metadata` }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ type: 'ImageMetadata', id: arg },
];
return tags;
},
keepUnusedDataFor: 86400, // 24 hours
}),
}), }),
}); });
export const { useGetImageDTOQuery } = imagesApi; export const { useGetImageDTOQuery, useGetImageMetadataQuery } = imagesApi;

View File

@ -33,25 +33,28 @@ type AnyModelConfigEntity =
| VaeModelConfigEntity; | VaeModelConfigEntity;
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({ const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({ const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
const controlNetModelsAdapter = const controlNetModelsAdapter =
createEntityAdapter<ControlNetModelConfigEntity>({ createEntityAdapter<ControlNetModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
const textualInversionModelsAdapter = const textualInversionModelsAdapter =
createEntityAdapter<TextualInversionModelConfigEntity>({ createEntityAdapter<TextualInversionModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({ const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
export const getModelId = ({ base_model, type, name }: AnyModelConfig) => export const getModelId = ({
`${base_model}/${type}/${name}`; base_model,
model_type,
model_name,
}: AnyModelConfig) => `${base_model}/${model_type}/${model_name}`;
const createModelEntities = <T extends AnyModelConfigEntity>( const createModelEntities = <T extends AnyModelConfigEntity>(
models: AnyModelConfig[] models: AnyModelConfig[]

View File

@ -1,3 +1,4 @@
import { FullTagDescription } from '@reduxjs/toolkit/dist/query/endpointDefinitions';
import { import {
BaseQueryFn, BaseQueryFn,
FetchArgs, FetchArgs,
@ -5,10 +6,9 @@ import {
createApi, createApi,
fetchBaseQuery, fetchBaseQuery,
} from '@reduxjs/toolkit/query/react'; } from '@reduxjs/toolkit/query/react';
import { FullTagDescription } from '@reduxjs/toolkit/dist/query/endpointDefinitions';
import { $authToken, $baseUrl } from 'services/api/client'; import { $authToken, $baseUrl } from 'services/api/client';
export const tagTypes = ['Board', 'Image', 'Model']; export const tagTypes = ['Board', 'Image', 'ImageMetadata', 'Model'];
export type ApiFullTagDescription = FullTagDescription< export type ApiFullTagDescription = FullTagDescription<
(typeof tagTypes)[number] (typeof tagTypes)[number]
>; >;

View File

@ -1,9 +1,9 @@
import queryString from 'query-string';
import { createAppAsyncThunk } from 'app/store/storeUtils'; import { createAppAsyncThunk } from 'app/store/storeUtils';
import { selectImagesAll } from 'features/gallery/store/gallerySlice'; import { selectImagesAll } from 'features/gallery/store/gallerySlice';
import { size } from 'lodash-es'; import { size } from 'lodash-es';
import { paths } from 'services/api/schema'; import queryString from 'query-string';
import { $client } from 'services/api/client'; import { $client } from 'services/api/client';
import { paths } from 'services/api/schema';
type GetImageUrlsArg = type GetImageUrlsArg =
paths['/api/v1/images/{image_name}/urls']['get']['parameters']['path']; paths['/api/v1/images/{image_name}/urls']['get']['parameters']['path'];
@ -24,7 +24,7 @@ export const imageUrlsReceived = createAppAsyncThunk<
GetImageUrlsResponse, GetImageUrlsResponse,
GetImageUrlsArg, GetImageUrlsArg,
GetImageUrlsThunkConfig GetImageUrlsThunkConfig
>('api/imageUrlsReceived', async (arg, { rejectWithValue }) => { >('thunkApi/imageUrlsReceived', async (arg, { rejectWithValue }) => {
const { image_name } = arg; const { image_name } = arg;
const { get } = $client.get(); const { get } = $client.get();
const { data, error, response } = await get( const { data, error, response } = await get(
@ -46,10 +46,10 @@ export const imageUrlsReceived = createAppAsyncThunk<
}); });
type GetImageMetadataArg = type GetImageMetadataArg =
paths['/api/v1/images/{image_name}/metadata']['get']['parameters']['path']; paths['/api/v1/images/{image_name}']['get']['parameters']['path'];
type GetImageMetadataResponse = type GetImageMetadataResponse =
paths['/api/v1/images/{image_name}/metadata']['get']['responses']['200']['content']['application/json']; paths['/api/v1/images/{image_name}']['get']['responses']['200']['content']['application/json'];
type GetImageMetadataThunkConfig = { type GetImageMetadataThunkConfig = {
rejectValue: { rejectValue: {
@ -58,21 +58,18 @@ type GetImageMetadataThunkConfig = {
}; };
}; };
export const imageMetadataReceived = createAppAsyncThunk< export const imageDTOReceived = createAppAsyncThunk<
GetImageMetadataResponse, GetImageMetadataResponse,
GetImageMetadataArg, GetImageMetadataArg,
GetImageMetadataThunkConfig GetImageMetadataThunkConfig
>('api/imageMetadataReceived', async (arg, { rejectWithValue }) => { >('thunkApi/imageMetadataReceived', async (arg, { rejectWithValue }) => {
const { image_name } = arg; const { image_name } = arg;
const { get } = $client.get(); const { get } = $client.get();
const { data, error, response } = await get( const { data, error, response } = await get('/api/v1/images/{image_name}', {
'/api/v1/images/{image_name}/metadata',
{
params: { params: {
path: { image_name }, path: { image_name },
}, },
} });
);
if (error) { if (error) {
return rejectWithValue({ arg, error }); return rejectWithValue({ arg, error });
@ -148,7 +145,7 @@ export const imageUploaded = createAppAsyncThunk<
UploadImageResponse, UploadImageResponse,
UploadImageArg, UploadImageArg,
UploadImageThunkConfig UploadImageThunkConfig
>('api/imageUploaded', async (arg, { rejectWithValue }) => { >('thunkApi/imageUploaded', async (arg, { rejectWithValue }) => {
const { const {
postUploadAction, postUploadAction,
file, file,
@ -199,7 +196,7 @@ export const imageDeleted = createAppAsyncThunk<
DeleteImageResponse, DeleteImageResponse,
DeleteImageArg, DeleteImageArg,
DeleteImageThunkConfig DeleteImageThunkConfig
>('api/imageDeleted', async (arg, { rejectWithValue }) => { >('thunkApi/imageDeleted', async (arg, { rejectWithValue }) => {
const { image_name } = arg; const { image_name } = arg;
const { del } = $client.get(); const { del } = $client.get();
const { data, error, response } = await del('/api/v1/images/{image_name}', { const { data, error, response } = await del('/api/v1/images/{image_name}', {
@ -235,7 +232,7 @@ export const imageUpdated = createAppAsyncThunk<
UpdateImageResponse, UpdateImageResponse,
UpdateImageArg, UpdateImageArg,
UpdateImageThunkConfig UpdateImageThunkConfig
>('api/imageUpdated', async (arg, { rejectWithValue }) => { >('thunkApi/imageUpdated', async (arg, { rejectWithValue }) => {
const { image_name, image_category, is_intermediate, session_id } = arg; const { image_name, image_category, is_intermediate, session_id } = arg;
const { patch } = $client.get(); const { patch } = $client.get();
const { data, error, response } = await patch('/api/v1/images/{image_name}', { const { data, error, response } = await patch('/api/v1/images/{image_name}', {
@ -284,7 +281,9 @@ export const receivedPageOfImages = createAppAsyncThunk<
ListImagesResponse, ListImagesResponse,
ListImagesArg, ListImagesArg,
ListImagesThunkConfig ListImagesThunkConfig
>('api/receivedPageOfImages', async (arg, { getState, rejectWithValue }) => { >(
'thunkApi/receivedPageOfImages',
async (arg, { getState, rejectWithValue }) => {
const { get } = $client.get(); const { get } = $client.get();
const state = getState(); const state = getState();
@ -326,4 +325,5 @@ export const receivedPageOfImages = createAppAsyncThunk<
} }
return data; return data;
}); }
);

View File

@ -19,6 +19,7 @@ export type ImageChanges = components['schemas']['ImageRecordChanges'];
export type ImageCategory = components['schemas']['ImageCategory']; export type ImageCategory = components['schemas']['ImageCategory'];
export type ResourceOrigin = components['schemas']['ResourceOrigin']; export type ResourceOrigin = components['schemas']['ResourceOrigin'];
export type ImageField = components['schemas']['ImageField']; export type ImageField = components['schemas']['ImageField'];
export type ImageMetadata = components['schemas']['ImageMetadata'];
export type OffsetPaginatedResults_BoardDTO_ = export type OffsetPaginatedResults_BoardDTO_ =
components['schemas']['OffsetPaginatedResults_BoardDTO_']; components['schemas']['OffsetPaginatedResults_BoardDTO_'];
export type OffsetPaginatedResults_ImageDTO_ = export type OffsetPaginatedResults_ImageDTO_ =
@ -31,6 +32,7 @@ export type MainModelField = components['schemas']['MainModelField'];
export type VAEModelField = components['schemas']['VAEModelField']; export type VAEModelField = components['schemas']['VAEModelField'];
export type LoRAModelField = components['schemas']['LoRAModelField']; export type LoRAModelField = components['schemas']['LoRAModelField'];
export type ModelsList = components['schemas']['ModelsList']; export type ModelsList = components['schemas']['ModelsList'];
export type ControlField = components['schemas']['ControlField'];
// Model Configs // Model Configs
export type LoRAModelConfig = components['schemas']['LoRAModelConfig']; export type LoRAModelConfig = components['schemas']['LoRAModelConfig'];
@ -107,6 +109,9 @@ export type MainModelLoaderInvocation = TypeReq<
export type LoraLoaderInvocation = TypeReq< export type LoraLoaderInvocation = TypeReq<
components['schemas']['LoraLoaderInvocation'] components['schemas']['LoraLoaderInvocation']
>; >;
export type MetadataAccumulatorInvocation = TypeReq<
components['schemas']['MetadataAccumulatorInvocation']
>;
// ControlNet Nodes // ControlNet Nodes
export type ControlNetInvocation = TypeReq< export type ControlNetInvocation = TypeReq<