mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
resolve conflicts
This commit is contained in:
@ -2,22 +2,26 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import {
|
||||
canvasSelector,
|
||||
isStagingSelector,
|
||||
} from 'features/canvas/store/canvasSelectors';
|
||||
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createSelector(
|
||||
[canvasSelector, isStagingSelector],
|
||||
(canvas, isStaging) => {
|
||||
[canvasSelector, isStagingSelector, uiSelector],
|
||||
(canvas, isStaging, ui) => {
|
||||
const { boundingBoxDimensions } = canvas;
|
||||
const { aspectRatio } = ui;
|
||||
return {
|
||||
boundingBoxDimensions,
|
||||
isStaging,
|
||||
aspectRatio,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
@ -25,7 +29,8 @@ const selector = createSelector(
|
||||
|
||||
const ParamBoundingBoxWidth = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { boundingBoxDimensions, isStaging } = useAppSelector(selector);
|
||||
const { boundingBoxDimensions, isStaging, aspectRatio } =
|
||||
useAppSelector(selector);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -36,6 +41,15 @@ const ParamBoundingBoxWidth = () => {
|
||||
height: Math.floor(v),
|
||||
})
|
||||
);
|
||||
if (aspectRatio) {
|
||||
const newWidth = roundToMultiple(v * aspectRatio, 64);
|
||||
dispatch(
|
||||
setBoundingBoxDimensions({
|
||||
width: newWidth,
|
||||
height: Math.floor(v),
|
||||
})
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const handleResetHeight = () => {
|
||||
@ -45,6 +59,15 @@ const ParamBoundingBoxWidth = () => {
|
||||
height: Math.floor(512),
|
||||
})
|
||||
);
|
||||
if (aspectRatio) {
|
||||
const newWidth = roundToMultiple(512 * aspectRatio, 64);
|
||||
dispatch(
|
||||
setBoundingBoxDimensions({
|
||||
width: newWidth,
|
||||
height: Math.floor(512),
|
||||
})
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
|
@ -0,0 +1,57 @@
|
||||
import { Flex, Spacer, Text } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { flipBoundingBoxAxes } from 'features/canvas/store/canvasSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { MdOutlineSwapVert } from 'react-icons/md';
|
||||
import ParamAspectRatio from '../../Core/ParamAspectRatio';
|
||||
import ParamBoundingBoxHeight from './ParamBoundingBoxHeight';
|
||||
import ParamBoundingBoxWidth from './ParamBoundingBoxWidth';
|
||||
|
||||
export default function ParamBoundingBoxSize() {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
gap: 2,
|
||||
p: 4,
|
||||
borderRadius: 4,
|
||||
flexDirection: 'column',
|
||||
w: 'full',
|
||||
bg: 'base.150',
|
||||
_dark: {
|
||||
bg: 'base.750',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<Text
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
width: 'full',
|
||||
color: 'base.700',
|
||||
_dark: {
|
||||
color: 'base.300',
|
||||
},
|
||||
}}
|
||||
>
|
||||
{t('parameters.aspectRatio')}
|
||||
</Text>
|
||||
<Spacer />
|
||||
<ParamAspectRatio />
|
||||
<IAIIconButton
|
||||
tooltip={t('ui.swapSizes')}
|
||||
aria-label={t('ui.swapSizes')}
|
||||
size="sm"
|
||||
icon={<MdOutlineSwapVert />}
|
||||
fontSize={20}
|
||||
onClick={() => dispatch(flipBoundingBoxAxes())}
|
||||
/>
|
||||
</Flex>
|
||||
<ParamBoundingBoxWidth />
|
||||
<ParamBoundingBoxHeight />
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -2,22 +2,26 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import {
|
||||
canvasSelector,
|
||||
isStagingSelector,
|
||||
} from 'features/canvas/store/canvasSelectors';
|
||||
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createSelector(
|
||||
[canvasSelector, isStagingSelector],
|
||||
(canvas, isStaging) => {
|
||||
[canvasSelector, isStagingSelector, uiSelector],
|
||||
(canvas, isStaging, ui) => {
|
||||
const { boundingBoxDimensions } = canvas;
|
||||
const { aspectRatio } = ui;
|
||||
return {
|
||||
boundingBoxDimensions,
|
||||
isStaging,
|
||||
aspectRatio,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
@ -25,7 +29,8 @@ const selector = createSelector(
|
||||
|
||||
const ParamBoundingBoxWidth = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { boundingBoxDimensions, isStaging } = useAppSelector(selector);
|
||||
const { boundingBoxDimensions, isStaging, aspectRatio } =
|
||||
useAppSelector(selector);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -36,6 +41,15 @@ const ParamBoundingBoxWidth = () => {
|
||||
width: Math.floor(v),
|
||||
})
|
||||
);
|
||||
if (aspectRatio) {
|
||||
const newHeight = roundToMultiple(v / aspectRatio, 64);
|
||||
dispatch(
|
||||
setBoundingBoxDimensions({
|
||||
width: Math.floor(v),
|
||||
height: newHeight,
|
||||
})
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const handleResetWidth = () => {
|
||||
@ -45,6 +59,15 @@ const ParamBoundingBoxWidth = () => {
|
||||
width: Math.floor(512),
|
||||
})
|
||||
);
|
||||
if (aspectRatio) {
|
||||
const newHeight = roundToMultiple(512 / aspectRatio, 64);
|
||||
dispatch(
|
||||
setBoundingBoxDimensions({
|
||||
width: Math.floor(512),
|
||||
height: newHeight,
|
||||
})
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
|
@ -1,23 +1,21 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { setInfillMethod } from 'features/parameters/store/generationSlice';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetAppConfigQuery } from 'services/api/endpoints/appInfo';
|
||||
|
||||
const selector = createSelector(
|
||||
[generationSelector, systemSelector],
|
||||
(parameters, system) => {
|
||||
const { infillMethod } = parameters;
|
||||
const { infillMethods } = system;
|
||||
[stateSelector],
|
||||
({ generation }) => {
|
||||
const { infillMethod } = generation;
|
||||
|
||||
return {
|
||||
infillMethod,
|
||||
infillMethods,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
@ -25,7 +23,11 @@ const selector = createSelector(
|
||||
|
||||
const ParamInfillMethod = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { infillMethod, infillMethods } = useAppSelector(selector);
|
||||
const { infillMethod } = useAppSelector(selector);
|
||||
|
||||
const { data: appConfigData, isLoading } = useGetAppConfigQuery();
|
||||
|
||||
const infill_methods = appConfigData?.infill_methods;
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -38,9 +40,11 @@ const ParamInfillMethod = () => {
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
disabled={infill_methods?.length === 0}
|
||||
placeholder={isLoading ? 'Loading...' : undefined}
|
||||
label={t('parameters.infillMethod')}
|
||||
value={infillMethod}
|
||||
data={infillMethods}
|
||||
data={infill_methods ?? []}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
);
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { setBoundingBoxScaleMethod } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
@ -35,7 +35,7 @@ const ParamScaleBeforeProcessing = () => {
|
||||
};
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
<IAIMantineSearchableSelect
|
||||
label={t('parameters.scaleBeforeProcessing')}
|
||||
data={BOUNDING_BOX_SCALES_DICT}
|
||||
value={boundingBoxScale}
|
||||
|
@ -2,12 +2,13 @@ import { Divider, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import ControlNet from 'features/controlNet/components/ControlNet';
|
||||
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
|
||||
import {
|
||||
controlNetAdded,
|
||||
controlNetModelChanged,
|
||||
controlNetSelector,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
|
||||
@ -15,6 +16,8 @@ import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { map } from 'lodash-es';
|
||||
import { Fragment, memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaPlus } from 'react-icons/fa';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
const selector = createSelector(
|
||||
@ -39,10 +42,23 @@ const ParamControlNetCollapse = () => {
|
||||
const { controlNetsArray, activeLabel } = useAppSelector(selector);
|
||||
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
|
||||
const dispatch = useAppDispatch();
|
||||
const { firstModel } = useGetControlNetModelsQuery(undefined, {
|
||||
selectFromResult: (result) => {
|
||||
const firstModel = result.data?.entities[result.data?.ids[0]];
|
||||
return {
|
||||
firstModel,
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
const handleClickedAddControlNet = useCallback(() => {
|
||||
dispatch(controlNetAdded({ controlNetId: uuidv4() }));
|
||||
}, [dispatch]);
|
||||
if (!firstModel) {
|
||||
return;
|
||||
}
|
||||
const controlNetId = uuidv4();
|
||||
dispatch(controlNetAdded({ controlNetId }));
|
||||
dispatch(controlNetModelChanged({ controlNetId, model: firstModel }));
|
||||
}, [dispatch, firstModel]);
|
||||
|
||||
if (isControlNetDisabled) {
|
||||
return null;
|
||||
@ -51,16 +67,39 @@ const ParamControlNetCollapse = () => {
|
||||
return (
|
||||
<IAICollapse label="ControlNet" activeLabel={activeLabel}>
|
||||
<Flex sx={{ flexDir: 'column', gap: 3 }}>
|
||||
<ParamControlNetFeatureToggle />
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
w: '100%',
|
||||
gap: 2,
|
||||
px: 4,
|
||||
py: 2,
|
||||
borderRadius: 4,
|
||||
bg: 'base.200',
|
||||
_dark: {
|
||||
bg: 'base.850',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<ParamControlNetFeatureToggle />
|
||||
</Flex>
|
||||
<IAIIconButton
|
||||
tooltip="Add ControlNet"
|
||||
aria-label="Add ControlNet"
|
||||
icon={<FaPlus />}
|
||||
isDisabled={!firstModel}
|
||||
flexGrow={1}
|
||||
size="md"
|
||||
onClick={handleClickedAddControlNet}
|
||||
/>
|
||||
</Flex>
|
||||
{controlNetsArray.map((c, i) => (
|
||||
<Fragment key={c.controlNetId}>
|
||||
{i > 0 && <Divider />}
|
||||
<ControlNet controlNet={c} />
|
||||
<ControlNet controlNetId={c.controlNetId} />
|
||||
</Fragment>
|
||||
))}
|
||||
<IAIButton flexGrow={1} onClick={handleClickedAddControlNet}>
|
||||
Add ControlNet
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
);
|
||||
|
@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { setAspectRatio } from 'features/ui/store/uiSlice';
|
||||
import { activeTabNameSelector } from '../../../../ui/store/uiSelectors';
|
||||
|
||||
const aspectRatios = [
|
||||
{ name: 'Free', value: null },
|
||||
@ -17,6 +18,10 @@ export default function ParamAspectRatio() {
|
||||
);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const shouldFitToWidthHeight = useAppSelector(
|
||||
(state: RootState) => state.generation.shouldFitToWidthHeight
|
||||
);
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
|
||||
return (
|
||||
<Flex gap={2} flexGrow={1}>
|
||||
@ -26,6 +31,9 @@ export default function ParamAspectRatio() {
|
||||
key={ratio.name}
|
||||
size="sm"
|
||||
isChecked={aspectRatio === ratio.value}
|
||||
isDisabled={
|
||||
activeTabName === 'img2img' ? !shouldFitToWidthHeight : false
|
||||
}
|
||||
onClick={() => dispatch(setAspectRatio(ratio.value))}
|
||||
>
|
||||
{ratio.name}
|
||||
|
@ -1,8 +1,8 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import ModelSelect from 'features/system/components/ModelSelect';
|
||||
import VAESelect from 'features/system/components/VAESelect';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { memo } from 'react';
|
||||
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
||||
import ParamMainModelSelect from '../MainModel/ParamMainModelSelect';
|
||||
import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect';
|
||||
import ParamScheduler from './ParamScheduler';
|
||||
|
||||
const ParamModelandVAEandScheduler = () => {
|
||||
@ -11,12 +11,12 @@ const ParamModelandVAEandScheduler = () => {
|
||||
return (
|
||||
<Flex gap={3} w="full" flexWrap={isVaeEnabled ? 'wrap' : 'nowrap'}>
|
||||
<Box w="full">
|
||||
<ModelSelect />
|
||||
<ParamMainModelSelect />
|
||||
</Box>
|
||||
<Flex gap={3} w="full">
|
||||
{isVaeEnabled && (
|
||||
<Box w="full">
|
||||
<VAESelect />
|
||||
<ParamVAEModelSelect />
|
||||
</Box>
|
||||
)}
|
||||
<Box w="full">
|
||||
|
@ -2,10 +2,10 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { setScheduler } from 'features/parameters/store/generationSlice';
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -48,7 +48,7 @@ const ParamScheduler = () => {
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
<IAIMantineSearchableSelect
|
||||
label={t('parameters.scheduler')}
|
||||
value={scheduler}
|
||||
data={data}
|
||||
|
@ -8,6 +8,7 @@ import { MdOutlineSwapVert } from 'react-icons/md';
|
||||
import ParamAspectRatio from './ParamAspectRatio';
|
||||
import ParamHeight from './ParamHeight';
|
||||
import ParamWidth from './ParamWidth';
|
||||
import { activeTabNameSelector } from '../../../../ui/store/uiSelectors';
|
||||
|
||||
export default function ParamSize() {
|
||||
const { t } = useTranslation();
|
||||
@ -15,6 +16,7 @@ export default function ParamSize() {
|
||||
const shouldFitToWidthHeight = useAppSelector(
|
||||
(state: RootState) => state.generation.shouldFitToWidthHeight
|
||||
);
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
@ -50,13 +52,24 @@ export default function ParamSize() {
|
||||
size="sm"
|
||||
icon={<MdOutlineSwapVert />}
|
||||
fontSize={20}
|
||||
isDisabled={
|
||||
activeTabName === 'img2img' ? !shouldFitToWidthHeight : false
|
||||
}
|
||||
onClick={() => dispatch(toggleSize())}
|
||||
/>
|
||||
</Flex>
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Flex gap={2} flexDirection="column" width="full">
|
||||
<ParamWidth isDisabled={!shouldFitToWidthHeight} />
|
||||
<ParamHeight isDisabled={!shouldFitToWidthHeight} />
|
||||
<ParamWidth
|
||||
isDisabled={
|
||||
activeTabName === 'img2img' ? !shouldFitToWidthHeight : false
|
||||
}
|
||||
/>
|
||||
<ParamHeight
|
||||
isDisabled={
|
||||
activeTabName === 'img2img' ? !shouldFitToWidthHeight : false
|
||||
}
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { FACETOOL_TYPES } from 'app/constants';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import {
|
||||
FacetoolType,
|
||||
setFacetoolType,
|
||||
@ -20,7 +20,7 @@ export default function FaceRestoreType() {
|
||||
dispatch(setFacetoolType(v as FacetoolType));
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
<IAIMantineSearchableSelect
|
||||
label={t('parameters.type')}
|
||||
data={FACETOOL_TYPES.concat()}
|
||||
value={facetoolType}
|
||||
|
@ -1,30 +1,24 @@
|
||||
import { Flex, Icon, Text } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useMemo } from 'react';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { FaImage } from 'react-icons/fa';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
(state) => {
|
||||
const { initialImage } = state.generation;
|
||||
const { asInitialImage: useBatchAsInitialImage, imageNames } = state.batch;
|
||||
return {
|
||||
initialImage,
|
||||
useBatchAsInitialImage,
|
||||
isResetButtonDisabled: useBatchAsInitialImage
|
||||
? imageNames.length === 0
|
||||
: !initialImage,
|
||||
isResetButtonDisabled: !initialImage,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
|
@ -1,22 +1,14 @@
|
||||
import { Flex, Spacer, Text } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { FaLayerGroup, FaUndo, FaUpload } from 'react-icons/fa';
|
||||
import useImageUploader from 'common/hooks/useImageUploader';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import {
|
||||
asInitialImageToggled,
|
||||
batchReset,
|
||||
} from 'features/batch/store/batchSlice';
|
||||
import BatchImageContainer from 'features/batch/components/BatchImageContainer';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import useImageUploader from 'common/hooks/useImageUploader';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { FaUndo, FaUpload } from 'react-icons/fa';
|
||||
import { PostUploadAction } from 'services/api/thunks/image';
|
||||
import InitialImage from './InitialImage';
|
||||
|
||||
@ -24,59 +16,34 @@ const selector = createSelector(
|
||||
[stateSelector],
|
||||
(state) => {
|
||||
const { initialImage } = state.generation;
|
||||
const { asInitialImage: useBatchAsInitialImage, imageNames } = state.batch;
|
||||
return {
|
||||
initialImage,
|
||||
useBatchAsInitialImage,
|
||||
isResetButtonDisabled: useBatchAsInitialImage
|
||||
? imageNames.length === 0
|
||||
: !initialImage,
|
||||
isResetButtonDisabled: !initialImage,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const postUploadAction: PostUploadAction = {
|
||||
type: 'SET_INITIAL_IMAGE',
|
||||
};
|
||||
|
||||
const InitialImageDisplay = () => {
|
||||
const { initialImage, useBatchAsInitialImage, isResetButtonDisabled } =
|
||||
useAppSelector(selector);
|
||||
const { isResetButtonDisabled } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { openUploader } = useImageUploader();
|
||||
|
||||
const {
|
||||
currentData: imageDTO,
|
||||
isLoading,
|
||||
isError,
|
||||
isSuccess,
|
||||
} = useGetImageDTOQuery(initialImage?.imageName ?? skipToken);
|
||||
|
||||
const postUploadAction = useMemo<PostUploadAction>(
|
||||
() =>
|
||||
useBatchAsInitialImage
|
||||
? { type: 'ADD_TO_BATCH' }
|
||||
: { type: 'SET_INITIAL_IMAGE' },
|
||||
[useBatchAsInitialImage]
|
||||
);
|
||||
|
||||
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
|
||||
postUploadAction,
|
||||
});
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
if (useBatchAsInitialImage) {
|
||||
dispatch(batchReset());
|
||||
} else {
|
||||
dispatch(clearInitialImage());
|
||||
}
|
||||
}, [dispatch, useBatchAsInitialImage]);
|
||||
dispatch(clearInitialImage());
|
||||
}, [dispatch]);
|
||||
|
||||
const handleUpload = useCallback(() => {
|
||||
openUploader();
|
||||
}, [openUploader]);
|
||||
|
||||
const handleClickUseBatch = useCallback(() => {
|
||||
dispatch(asInitialImageToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
layerStyle={'first'}
|
||||
@ -114,40 +81,22 @@ const InitialImageDisplay = () => {
|
||||
Initial Image
|
||||
</Text>
|
||||
<Spacer />
|
||||
{/* <IAIButton
|
||||
tooltip={useBatchAsInitialImage ? 'Disable Batch' : 'Enable Batch'}
|
||||
aria-label={useBatchAsInitialImage ? 'Disable Batch' : 'Enable Batch'}
|
||||
leftIcon={<FaLayerGroup />}
|
||||
isChecked={useBatchAsInitialImage}
|
||||
onClick={handleClickUseBatch}
|
||||
>
|
||||
{useBatchAsInitialImage ? 'Batch' : 'Single'}
|
||||
</IAIButton> */}
|
||||
<IAIIconButton
|
||||
tooltip={
|
||||
useBatchAsInitialImage ? 'Upload to Batch' : 'Upload Initial Image'
|
||||
}
|
||||
aria-label={
|
||||
useBatchAsInitialImage ? 'Upload to Batch' : 'Upload Initial Image'
|
||||
}
|
||||
tooltip={'Upload Initial Image'}
|
||||
aria-label={'Upload Initial Image'}
|
||||
icon={<FaUpload />}
|
||||
onClick={handleUpload}
|
||||
{...getUploadButtonProps()}
|
||||
/>
|
||||
<IAIIconButton
|
||||
tooltip={
|
||||
useBatchAsInitialImage ? 'Reset Batch' : 'Reset Initial Image'
|
||||
}
|
||||
aria-label={
|
||||
useBatchAsInitialImage ? 'Reset Batch' : 'Reset Initial Image'
|
||||
}
|
||||
tooltip={'Reset Initial Image'}
|
||||
aria-label={'Reset Initial Image'}
|
||||
icon={<FaUndo />}
|
||||
onClick={handleReset}
|
||||
isDisabled={isResetButtonDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
<InitialImage />
|
||||
{/* {useBatchAsInitialImage ? <BatchImageContainer /> : <InitialImage />} */}
|
||||
<input {...getUploadInputProps()} />
|
||||
</Flex>
|
||||
);
|
||||
|
@ -0,0 +1,100 @@
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => ({ model: state.generation.model }),
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamMainModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { model } = useAppSelector(selector);
|
||||
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!mainModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(mainModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [mainModels]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
// TODO: maybe we should just store the full model entity in state?
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
|
||||
null,
|
||||
[mainModels?.entities, model]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newModel = modelIdToMainModelParam(v);
|
||||
|
||||
if (!newModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(modelSelected(newModel));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return isLoading ? (
|
||||
<IAIMantineSearchableSelect
|
||||
label={t('modelManager.model')}
|
||||
placeholder="Loading..."
|
||||
disabled={true}
|
||||
data={[]}
|
||||
/>
|
||||
) : (
|
||||
<IAIMantineSearchableSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={t('modelManager.model')}
|
||||
value={selectedModel?.id}
|
||||
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||
data={data}
|
||||
error={data.length === 0}
|
||||
disabled={data.length === 0}
|
||||
onChange={handleChangeModel}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamMainModelSelect);
|
@ -27,6 +27,9 @@ const ParamNoiseCollapse = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const isNoiseEnabled = useFeatureStatus('noise').isFeatureEnabled;
|
||||
const isPerlinNoiseEnabled = useFeatureStatus('perlinNoise').isFeatureEnabled;
|
||||
const isNoiseThresholdEnabled =
|
||||
useFeatureStatus('noiseThreshold').isFeatureEnabled;
|
||||
|
||||
const { activeLabel } = useAppSelector(selector);
|
||||
|
||||
@ -42,8 +45,8 @@ const ParamNoiseCollapse = () => {
|
||||
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
||||
<ParamNoiseToggle />
|
||||
<ParamCpuNoiseToggle />
|
||||
<ParamPerlinNoise />
|
||||
<ParamNoiseThreshold />
|
||||
{isPerlinNoiseEnabled && <ParamPerlinNoise />}
|
||||
{isNoiseThresholdEnabled && <ParamNoiseThreshold />}
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
);
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { UPSCALING_LEVELS } from 'app/constants';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import {
|
||||
UpscalingLevel,
|
||||
setUpscalingLevel,
|
||||
@ -24,7 +24,7 @@ export default function UpscaleScale() {
|
||||
dispatch(setUpscalingLevel(Number(v) as UpscalingLevel));
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
<IAIMantineSearchableSelect
|
||||
disabled={!isESRGANAvailable}
|
||||
label={t('parameters.scale')}
|
||||
value={String(upscalingLevel)}
|
||||
|
@ -0,0 +1,110 @@
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { vaeSelected } from 'features/parameters/store/generationSlice';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ generation }) => {
|
||||
const { model, vae } = generation;
|
||||
return { model, vae };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamVAEModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { model, vae } = useAppSelector(selector);
|
||||
|
||||
const { data: vaeModels } = useGetVaeModelsQuery();
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!vaeModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
// add a "default" option, this means use the main model's included VAE
|
||||
const data: SelectItem[] = [
|
||||
{
|
||||
value: 'default',
|
||||
label: 'Default',
|
||||
group: 'Default',
|
||||
},
|
||||
];
|
||||
|
||||
forEach(vaeModels.entities, (vae, id) => {
|
||||
if (!vae) {
|
||||
return;
|
||||
}
|
||||
|
||||
const disabled = model?.base_model !== vae.base_model;
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: vae.model_name,
|
||||
group: MODEL_TYPE_MAP[vae.base_model],
|
||||
disabled,
|
||||
tooltip: disabled
|
||||
? `Incompatible base model: ${vae.base_model}`
|
||||
: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
|
||||
}, [vaeModels, model?.base_model]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
const selectedVaeModel = useMemo(
|
||||
() =>
|
||||
vaeModels?.entities[`${vae?.base_model}/vae/${vae?.model_name}`] ?? null,
|
||||
[vaeModels?.entities, vae]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v || v === 'default') {
|
||||
dispatch(vaeSelected(null));
|
||||
return;
|
||||
}
|
||||
|
||||
const newVaeModel = modelIdToVAEModelParam(v);
|
||||
|
||||
if (!newVaeModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(vaeSelected(newVaeModel));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSearchableSelect
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
tooltip={selectedVaeModel?.description}
|
||||
label={t('modelManager.vae')}
|
||||
value={selectedVaeModel?.id ?? 'default'}
|
||||
placeholder="Default"
|
||||
data={data}
|
||||
onChange={handleChangeModel}
|
||||
disabled={data.length === 0}
|
||||
clearable
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamVAEModelSelect);
|
@ -17,8 +17,10 @@ import { FaPlay } from 'react-icons/fa';
|
||||
const IN_PROGRESS_STYLES: ChakraProps['sx'] = {
|
||||
_disabled: {
|
||||
bg: 'none',
|
||||
color: 'base.600',
|
||||
cursor: 'not-allowed',
|
||||
_hover: {
|
||||
color: 'base.600',
|
||||
bg: 'none',
|
||||
},
|
||||
},
|
||||
|
@ -2,6 +2,7 @@ import { useAppToaster } from 'app/components/Toaster';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { UnsafeImageMetadata } from 'services/api/endpoints/images';
|
||||
import { isImageField } from 'services/api/guards';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||
@ -27,7 +28,7 @@ import {
|
||||
isValidSteps,
|
||||
isValidStrength,
|
||||
isValidWidth,
|
||||
} from '../store/parameterZodSchemas';
|
||||
} from '../types/parameterSchemas';
|
||||
|
||||
export const useRecallParameters = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
@ -162,7 +163,7 @@ export const useRecallParameters = () => {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
dispatch(modelSelected(model?.id || ''));
|
||||
dispatch(modelSelected(model));
|
||||
parameterSetToast();
|
||||
},
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
@ -269,28 +270,24 @@ export const useRecallParameters = () => {
|
||||
);
|
||||
|
||||
const recallAllParameters = useCallback(
|
||||
(image: ImageDTO | undefined) => {
|
||||
if (!image || !image.metadata) {
|
||||
(metadata: UnsafeImageMetadata['metadata'] | undefined) => {
|
||||
if (!metadata) {
|
||||
allParameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
|
||||
const {
|
||||
cfg_scale,
|
||||
height,
|
||||
model,
|
||||
positive_conditioning,
|
||||
negative_conditioning,
|
||||
positive_prompt,
|
||||
negative_prompt,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
width,
|
||||
strength,
|
||||
clip,
|
||||
extra,
|
||||
latents,
|
||||
unet,
|
||||
vae,
|
||||
} = image.metadata;
|
||||
} = metadata;
|
||||
|
||||
if (isValidCfgScale(cfg_scale)) {
|
||||
dispatch(setCfgScale(cfg_scale));
|
||||
@ -298,11 +295,11 @@ export const useRecallParameters = () => {
|
||||
if (isValidMainModel(model)) {
|
||||
dispatch(modelSelected(model));
|
||||
}
|
||||
if (isValidPositivePrompt(positive_conditioning)) {
|
||||
dispatch(setPositivePrompt(positive_conditioning));
|
||||
if (isValidPositivePrompt(positive_prompt)) {
|
||||
dispatch(setPositivePrompt(positive_prompt));
|
||||
}
|
||||
if (isValidNegativePrompt(negative_conditioning)) {
|
||||
dispatch(setNegativePrompt(negative_conditioning));
|
||||
if (isValidNegativePrompt(negative_prompt)) {
|
||||
dispatch(setNegativePrompt(negative_prompt));
|
||||
}
|
||||
if (isValidScheduler(scheduler)) {
|
||||
dispatch(setScheduler(scheduler));
|
||||
|
@ -1,8 +1,10 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { ImageDTO, MainModelField } from 'services/api/types';
|
||||
|
||||
export const initialImageSelected = createAction<ImageDTO | string | undefined>(
|
||||
'generation/initialImageSelected'
|
||||
);
|
||||
|
||||
export const modelSelected = createAction<string>('generation/modelSelected');
|
||||
export const modelSelected = createAction<MainModelField>(
|
||||
'generation/modelSelected'
|
||||
);
|
||||
|
@ -8,7 +8,7 @@ import {
|
||||
setShouldShowAdvancedOptions,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { clamp } from 'lodash-es';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { ImageDTO, MainModelField } from 'services/api/types';
|
||||
import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
|
||||
import {
|
||||
CfgScaleParam,
|
||||
@ -23,7 +23,7 @@ import {
|
||||
VaeModelParam,
|
||||
WidthParam,
|
||||
zMainModel,
|
||||
} from './parameterZodSchemas';
|
||||
} from '../types/parameterSchemas';
|
||||
|
||||
export interface GenerationState {
|
||||
cfgScale: CfgScaleParam;
|
||||
@ -54,7 +54,7 @@ export interface GenerationState {
|
||||
shouldUseSymmetry: boolean;
|
||||
horizontalSymmetrySteps: number;
|
||||
verticalSymmetrySteps: number;
|
||||
model: MainModelParam | null;
|
||||
model: MainModelField | null;
|
||||
vae: VaeModelParam | null;
|
||||
seamlessXAxis: boolean;
|
||||
seamlessYAxis: boolean;
|
||||
@ -227,24 +227,19 @@ export const generationSlice = createSlice({
|
||||
const { image_name, width, height } = action.payload;
|
||||
state.initialImage = { imageName: image_name, width, height };
|
||||
},
|
||||
modelSelected: (state, action: PayloadAction<string>) => {
|
||||
const [base_model, type, name] = action.payload.split('/');
|
||||
modelChanged: (state, action: PayloadAction<MainModelParam | null>) => {
|
||||
state.model = action.payload;
|
||||
|
||||
state.model = zMainModel.parse({
|
||||
id: action.payload,
|
||||
base_model,
|
||||
name,
|
||||
type,
|
||||
});
|
||||
if (state.model === null) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Clamp ClipSkip Based On Selected Model
|
||||
const { maxClip } = clipSkipMap[state.model.base_model];
|
||||
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
|
||||
},
|
||||
modelChanged: (state, action: PayloadAction<MainModelParam>) => {
|
||||
state.model = action.payload;
|
||||
},
|
||||
vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
|
||||
// null is a valid VAE!
|
||||
state.vae = action.payload;
|
||||
},
|
||||
setClipSkip: (state, action: PayloadAction<number>) => {
|
||||
@ -260,11 +255,15 @@ export const generationSlice = createSlice({
|
||||
|
||||
if (defaultModel && !state.model) {
|
||||
const [base_model, model_type, model_name] = defaultModel.split('/');
|
||||
state.model = zMainModel.parse({
|
||||
id: defaultModel,
|
||||
name: model_name,
|
||||
|
||||
const result = zMainModel.safeParse({
|
||||
model_name,
|
||||
base_model,
|
||||
});
|
||||
|
||||
if (result.success) {
|
||||
state.model = result.data;
|
||||
}
|
||||
}
|
||||
});
|
||||
builder.addCase(setShouldShowAdvancedOptions, (state, action) => {
|
||||
|
@ -0,0 +1,6 @@
|
||||
export const MODEL_TYPE_MAP = {
|
||||
'sd-1': 'Stable Diffusion 1.x',
|
||||
'sd-2': 'Stable Diffusion 2.x',
|
||||
sdxl: 'Stable Diffusion XL',
|
||||
'sdxl-refiner': 'Stable Diffusion XL Refiner',
|
||||
};
|
@ -135,8 +135,7 @@ export type BaseModelParam = z.infer<typeof zBaseModel>;
|
||||
* TODO: Make this a dynamically generated enum?
|
||||
*/
|
||||
export const zMainModel = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
model_name: z.string().min(1),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
|
||||
@ -153,8 +152,7 @@ export const isValidMainModel = (val: unknown): val is MainModelParam =>
|
||||
* Zod schema for VAE parameter
|
||||
*/
|
||||
export const zVaeModel = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
model_name: z.string().min(1),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
/**
|
||||
@ -170,8 +168,7 @@ export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
|
||||
* Zod schema for LoRA
|
||||
*/
|
||||
export const zLoRAModel = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
model_name: z.string().min(1),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
/**
|
||||
@ -183,6 +180,23 @@ export type LoRAModelParam = z.infer<typeof zLoRAModel>;
|
||||
*/
|
||||
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
|
||||
zLoRAModel.safeParse(val).success;
|
||||
/**
|
||||
* Zod schema for ControlNet models
|
||||
*/
|
||||
export const zControlNetModel = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
|
||||
/**
|
||||
* Validates/type-guards a value as a model parameter
|
||||
*/
|
||||
export const isValidControlNetModel = (
|
||||
val: unknown
|
||||
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for l2l strength parameter
|
@ -0,0 +1,30 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { zControlNetModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { ControlNetModelField } from 'services/api/types';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
export const modelIdToControlNetModelParam = (
|
||||
controlNetModelId: string
|
||||
): ControlNetModelField | undefined => {
|
||||
const [base_model, model_type, model_name] = controlNetModelId.split('/');
|
||||
|
||||
const result = zControlNetModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
moduleLog.error(
|
||||
{
|
||||
controlNetModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse ControlNet model id'
|
||||
);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
@ -0,0 +1,28 @@
|
||||
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
export const modelIdToLoRAModelParam = (
|
||||
loraModelId: string
|
||||
): LoRAModelParam | undefined => {
|
||||
const [base_model, model_type, model_name] = loraModelId.split('/');
|
||||
|
||||
const result = zLoRAModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
moduleLog.error(
|
||||
{
|
||||
loraModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse LoRA model id'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
@ -0,0 +1,31 @@
|
||||
import {
|
||||
MainModelParam,
|
||||
zMainModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
export const modelIdToMainModelParam = (
|
||||
mainModelId: string
|
||||
): MainModelParam | undefined => {
|
||||
const [base_model, model_type, model_name] = mainModelId.split('/');
|
||||
|
||||
const result = zMainModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
moduleLog.error(
|
||||
{
|
||||
mainModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse main model id'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
@ -0,0 +1,28 @@
|
||||
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
export const modelIdToVAEModelParam = (
|
||||
vaeModelId: string
|
||||
): VaeModelParam | undefined => {
|
||||
const [base_model, model_type, model_name] = vaeModelId.split('/');
|
||||
|
||||
const result = zVaeModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
moduleLog.error(
|
||||
{
|
||||
vaeModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse VAE model id'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
Reference in New Issue
Block a user