mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: Adjust clip skip layer count based on model (#3675)
Clip Skip breaks when you supply a number greater than the number of layers for the model type. So capping this out based on the model on the frontend - `sd-1` at 12 - `sd-2` at 24 - Will update later to whatever SDXL needs if it is different. - Also fixes LoRA's breaking with Clip Skip.
This commit is contained in:
commit
909fe047e4
@ -4,6 +4,7 @@ import { forEach, size } from 'lodash-es';
|
|||||||
import { LoraLoaderInvocation } from 'services/api/types';
|
import { LoraLoaderInvocation } from 'services/api/types';
|
||||||
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
|
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
|
||||||
import {
|
import {
|
||||||
|
CLIP_SKIP,
|
||||||
LORA_LOADER,
|
LORA_LOADER,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
@ -27,14 +28,19 @@ export const addLoRAsToGraph = (
|
|||||||
const loraCount = size(loras);
|
const loraCount = size(loras);
|
||||||
|
|
||||||
if (loraCount > 0) {
|
if (loraCount > 0) {
|
||||||
// remove any existing connections from main model loader, we need to insert the lora nodes
|
// Remove MAIN_MODEL_LOADER unet connection to feed it to LoRAs
|
||||||
graph.edges = graph.edges.filter(
|
graph.edges = graph.edges.filter(
|
||||||
(e) =>
|
(e) =>
|
||||||
!(
|
!(
|
||||||
e.source.node_id === MAIN_MODEL_LOADER &&
|
e.source.node_id === MAIN_MODEL_LOADER &&
|
||||||
['unet', 'clip'].includes(e.source.field)
|
['unet'].includes(e.source.field)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
// Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
|
||||||
|
graph.edges = graph.edges.filter(
|
||||||
|
(e) =>
|
||||||
|
!(e.source.node_id === CLIP_SKIP && ['clip'].includes(e.source.field))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// we need to remember the last lora so we can chain from it
|
// we need to remember the last lora so we can chain from it
|
||||||
@ -73,7 +79,7 @@ export const addLoRAsToGraph = (
|
|||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: MAIN_MODEL_LOADER,
|
node_id: CLIP_SKIP,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
|
@ -11,7 +11,7 @@ const selector = createSelector(
|
|||||||
(state: RootState) => {
|
(state: RootState) => {
|
||||||
const clipSkip = state.generation.clipSkip;
|
const clipSkip = state.generation.clipSkip;
|
||||||
return {
|
return {
|
||||||
activeLabel: clipSkip > 0 ? `Clip Skip Active` : undefined,
|
activeLabel: clipSkip > 0 ? 'Clip Skip' : undefined,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
|
@ -5,11 +5,26 @@ import { setClipSkip } from 'features/parameters/store/generationSlice';
|
|||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
export const clipSkipMap = {
|
||||||
|
'sd-1': {
|
||||||
|
maxClip: 12,
|
||||||
|
markers: [0, 1, 2, 3, 4, 8, 12],
|
||||||
|
},
|
||||||
|
'sd-2': {
|
||||||
|
maxClip: 24,
|
||||||
|
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
export default function ParamClipSkip() {
|
export default function ParamClipSkip() {
|
||||||
const clipSkip = useAppSelector(
|
const clipSkip = useAppSelector(
|
||||||
(state: RootState) => state.generation.clipSkip
|
(state: RootState) => state.generation.clipSkip
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const selectedModelId = useAppSelector(
|
||||||
|
(state: RootState) => state.generation.model
|
||||||
|
).split('/')[0];
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
@ -29,12 +44,14 @@ export default function ParamClipSkip() {
|
|||||||
label={t('parameters.clipSkip')}
|
label={t('parameters.clipSkip')}
|
||||||
aria-label={t('parameters.clipSkip')}
|
aria-label={t('parameters.clipSkip')}
|
||||||
min={0}
|
min={0}
|
||||||
max={30}
|
max={clipSkipMap[selectedModelId as keyof typeof clipSkipMap].maxClip}
|
||||||
step={1}
|
step={1}
|
||||||
value={clipSkip}
|
value={clipSkip}
|
||||||
onChange={handleClipSkipChange}
|
onChange={handleClipSkipChange}
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
sliderMarks={[0, 1, 2, 3, 5, 10, 15, 25, 30]}
|
sliderMarks={
|
||||||
|
clipSkipMap[selectedModelId as keyof typeof clipSkipMap].markers
|
||||||
|
}
|
||||||
withInput
|
withInput
|
||||||
withReset
|
withReset
|
||||||
handleReset={handleClipSkipReset}
|
handleReset={handleClipSkipReset}
|
||||||
|
@ -5,6 +5,7 @@ import { configChanged } from 'features/system/store/configSlice';
|
|||||||
import { setShouldShowAdvancedOptions } from 'features/ui/store/uiSlice';
|
import { setShouldShowAdvancedOptions } from 'features/ui/store/uiSlice';
|
||||||
import { clamp } from 'lodash-es';
|
import { clamp } from 'lodash-es';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
|
||||||
import {
|
import {
|
||||||
CfgScaleParam,
|
CfgScaleParam,
|
||||||
HeightParam,
|
HeightParam,
|
||||||
@ -216,6 +217,12 @@ export const generationSlice = createSlice({
|
|||||||
},
|
},
|
||||||
modelSelected: (state, action: PayloadAction<string>) => {
|
modelSelected: (state, action: PayloadAction<string>) => {
|
||||||
state.model = action.payload;
|
state.model = action.payload;
|
||||||
|
|
||||||
|
// Clamp ClipSkip Based On Selected Model
|
||||||
|
const clipSkipMax =
|
||||||
|
clipSkipMap[action.payload.split('/')[0] as keyof typeof clipSkipMap]
|
||||||
|
.maxClip;
|
||||||
|
state.clipSkip = clamp(state.clipSkip, 0, clipSkipMax);
|
||||||
},
|
},
|
||||||
vaeSelected: (state, action: PayloadAction<string>) => {
|
vaeSelected: (state, action: PayloadAction<string>) => {
|
||||||
state.vae = action.payload;
|
state.vae = action.payload;
|
||||||
|
Loading…
Reference in New Issue
Block a user