fix: Adjust clip skip layer count based on model (#3675)

Clip Skip breaks when you supply a number greater than the number of
layers for the model type. So capping this out based on the model on the
frontend

- `sd-1` at 12
- `sd-2` at 24
- Will update later to whatever SDXL needs if it is different.

- Also fixes LoRA's breaking with Clip Skip.
This commit is contained in:
blessedcoolant 2023-07-07 23:46:09 +12:00 committed by GitHub
commit 909fe047e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 36 additions and 6 deletions

View File

@ -4,6 +4,7 @@ import { forEach, size } from 'lodash-es';
import { LoraLoaderInvocation } from 'services/api/types'; import { LoraLoaderInvocation } from 'services/api/types';
import { modelIdToLoRAModelField } from '../modelIdToLoRAName'; import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
import { import {
CLIP_SKIP,
LORA_LOADER, LORA_LOADER,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
@ -27,14 +28,19 @@ export const addLoRAsToGraph = (
const loraCount = size(loras); const loraCount = size(loras);
if (loraCount > 0) { if (loraCount > 0) {
// remove any existing connections from main model loader, we need to insert the lora nodes // Remove MAIN_MODEL_LOADER unet connection to feed it to LoRAs
graph.edges = graph.edges.filter( graph.edges = graph.edges.filter(
(e) => (e) =>
!( !(
e.source.node_id === MAIN_MODEL_LOADER && e.source.node_id === MAIN_MODEL_LOADER &&
['unet', 'clip'].includes(e.source.field) ['unet'].includes(e.source.field)
) )
); );
// Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
graph.edges = graph.edges.filter(
(e) =>
!(e.source.node_id === CLIP_SKIP && ['clip'].includes(e.source.field))
);
} }
// we need to remember the last lora so we can chain from it // we need to remember the last lora so we can chain from it
@ -73,7 +79,7 @@ export const addLoRAsToGraph = (
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: MAIN_MODEL_LOADER, node_id: CLIP_SKIP,
field: 'clip', field: 'clip',
}, },
destination: { destination: {

View File

@ -11,7 +11,7 @@ const selector = createSelector(
(state: RootState) => { (state: RootState) => {
const clipSkip = state.generation.clipSkip; const clipSkip = state.generation.clipSkip;
return { return {
activeLabel: clipSkip > 0 ? `Clip Skip Active` : undefined, activeLabel: clipSkip > 0 ? 'Clip Skip' : undefined,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions

View File

@ -5,11 +5,26 @@ import { setClipSkip } from 'features/parameters/store/generationSlice';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
export const clipSkipMap = {
'sd-1': {
maxClip: 12,
markers: [0, 1, 2, 3, 4, 8, 12],
},
'sd-2': {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
};
export default function ParamClipSkip() { export default function ParamClipSkip() {
const clipSkip = useAppSelector( const clipSkip = useAppSelector(
(state: RootState) => state.generation.clipSkip (state: RootState) => state.generation.clipSkip
); );
const selectedModelId = useAppSelector(
(state: RootState) => state.generation.model
).split('/')[0];
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -29,12 +44,14 @@ export default function ParamClipSkip() {
label={t('parameters.clipSkip')} label={t('parameters.clipSkip')}
aria-label={t('parameters.clipSkip')} aria-label={t('parameters.clipSkip')}
min={0} min={0}
max={30} max={clipSkipMap[selectedModelId as keyof typeof clipSkipMap].maxClip}
step={1} step={1}
value={clipSkip} value={clipSkip}
onChange={handleClipSkipChange} onChange={handleClipSkipChange}
withSliderMarks withSliderMarks
sliderMarks={[0, 1, 2, 3, 5, 10, 15, 25, 30]} sliderMarks={
clipSkipMap[selectedModelId as keyof typeof clipSkipMap].markers
}
withInput withInput
withReset withReset
handleReset={handleClipSkipReset} handleReset={handleClipSkipReset}

View File

@ -5,6 +5,7 @@ import { configChanged } from 'features/system/store/configSlice';
import { setShouldShowAdvancedOptions } from 'features/ui/store/uiSlice'; import { setShouldShowAdvancedOptions } from 'features/ui/store/uiSlice';
import { clamp } from 'lodash-es'; import { clamp } from 'lodash-es';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
import { import {
CfgScaleParam, CfgScaleParam,
HeightParam, HeightParam,
@ -216,6 +217,12 @@ export const generationSlice = createSlice({
}, },
modelSelected: (state, action: PayloadAction<string>) => { modelSelected: (state, action: PayloadAction<string>) => {
state.model = action.payload; state.model = action.payload;
// Clamp ClipSkip Based On Selected Model
const clipSkipMax =
clipSkipMap[action.payload.split('/')[0] as keyof typeof clipSkipMap]
.maxClip;
state.clipSkip = clamp(state.clipSkip, 0, clipSkipMax);
}, },
vaeSelected: (state, action: PayloadAction<string>) => { vaeSelected: (state, action: PayloadAction<string>) => {
state.vae = action.payload; state.vae = action.payload;