feat(ui): layers recall

This still needs some finessing - needs logic depending on the tab...
This commit is contained in:
psychedelicious 2024-05-07 17:47:09 +10:00 committed by Kent Keirsey
parent ccd399e277
commit e537de2f6d
7 changed files with 112 additions and 3 deletions

View File

@ -1559,7 +1559,9 @@
"opacityFilter": "Opacity Filter", "opacityFilter": "Opacity Filter",
"clearProcessor": "Clear Processor", "clearProcessor": "Clear Processor",
"resetProcessor": "Reset Processor to Defaults", "resetProcessor": "Reset Processor to Defaults",
"noLayersAdded": "No Layers Added" "noLayersAdded": "No Layers Added",
"layers_one": "Layer",
"layers_other": "Layers"
}, },
"ui": { "ui": {
"tabs": { "tabs": {

View File

@ -255,6 +255,10 @@ export const controlLayersSlice = createSlice({
payload: { layerId: uuidv4(), controlAdapter }, payload: { layerId: uuidv4(), controlAdapter },
}), }),
}, },
caLayerRecalled: (state, action: PayloadAction<ControlAdapterLayer>) => {
state.layers.push({ ...action.payload, isSelected: true });
state.selectedLayerId = action.payload.id;
},
caLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => { caLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => {
const { layerId, imageDTO } = action.payload; const { layerId, imageDTO } = action.payload;
const layer = selectCALayerOrThrow(state, layerId); const layer = selectCALayerOrThrow(state, layerId);
@ -368,6 +372,9 @@ export const controlLayersSlice = createSlice({
}, },
prepare: (ipAdapter: IPAdapterConfigV2) => ({ payload: { layerId: uuidv4(), ipAdapter } }), prepare: (ipAdapter: IPAdapterConfigV2) => ({ payload: { layerId: uuidv4(), ipAdapter } }),
}, },
ipaLayerRecalled: (state, action: PayloadAction<IPAdapterLayer>) => {
state.layers.push(action.payload);
},
ipaLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => { ipaLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => {
const { layerId, imageDTO } = action.payload; const { layerId, imageDTO } = action.payload;
const layer = selectIPALayerOrThrow(state, layerId); const layer = selectIPALayerOrThrow(state, layerId);
@ -462,6 +469,10 @@ export const controlLayersSlice = createSlice({
}, },
prepare: () => ({ payload: { layerId: uuidv4() } }), prepare: () => ({ payload: { layerId: uuidv4() } }),
}, },
rgLayerRecalled: (state, action: PayloadAction<RegionalGuidanceLayer>) => {
state.layers.push({ ...action.payload, isSelected: true });
state.selectedLayerId = action.payload.id;
},
rgLayerPositivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => { rgLayerPositivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
const { layerId, prompt } = action.payload; const { layerId, prompt } = action.payload;
const layer = selectRGLayerOrThrow(state, layerId); const layer = selectRGLayerOrThrow(state, layerId);
@ -805,6 +816,7 @@ export const {
allLayersDeleted, allLayersDeleted,
// CA Layers // CA Layers
caLayerAdded, caLayerAdded,
caLayerRecalled,
caLayerImageChanged, caLayerImageChanged,
caLayerProcessedImageChanged, caLayerProcessedImageChanged,
caLayerModelChanged, caLayerModelChanged,
@ -817,6 +829,7 @@ export const {
caLayerT2IAdaptersDeleted, caLayerT2IAdaptersDeleted,
// IPA Layers // IPA Layers
ipaLayerAdded, ipaLayerAdded,
ipaLayerRecalled,
ipaLayerImageChanged, ipaLayerImageChanged,
ipaLayerMethodChanged, ipaLayerMethodChanged,
ipaLayerModelChanged, ipaLayerModelChanged,
@ -827,6 +840,7 @@ export const {
caOrIPALayerBeginEndStepPctChanged, caOrIPALayerBeginEndStepPctChanged,
// RG Layers // RG Layers
rgLayerAdded, rgLayerAdded,
rgLayerRecalled,
rgLayerPositivePromptChanged, rgLayerPositivePromptChanged,
rgLayerNegativePromptChanged, rgLayerNegativePromptChanged,
rgLayerPreviewColorChanged, rgLayerPreviewColorChanged,

View File

@ -51,6 +51,7 @@ const ImageMetadataActions = (props: Props) => {
<MetadataItem metadata={metadata} handlers={handlers.refinerScheduler} /> <MetadataItem metadata={metadata} handlers={handlers.refinerScheduler} />
<MetadataItem metadata={metadata} handlers={handlers.refinerStart} /> <MetadataItem metadata={metadata} handlers={handlers.refinerStart} />
<MetadataItem metadata={metadata} handlers={handlers.refinerSteps} /> <MetadataItem metadata={metadata} handlers={handlers.refinerSteps} />
<MetadataItem metadata={metadata} handlers={handlers.layers} />
<MetadataLoRAs metadata={metadata} /> <MetadataLoRAs metadata={metadata} />
{activeTabName !== 'generation' && <MetadataControlNets metadata={metadata} />} {activeTabName !== 'generation' && <MetadataControlNets metadata={metadata} />}
{activeTabName !== 'generation' && <MetadataT2IAdapters metadata={metadata} />} {activeTabName !== 'generation' && <MetadataT2IAdapters metadata={metadata} />}

View File

@ -1,5 +1,6 @@
import { objectKeys } from 'common/util/objectKeys'; import { objectKeys } from 'common/util/objectKeys';
import { toast } from 'common/util/toast'; import { toast } from 'common/util/toast';
import type { Layer } from 'features/controlLayers/store/types';
import type { LoRA } from 'features/lora/store/loraSlice'; import type { LoRA } from 'features/lora/store/loraSlice';
import type { import type {
AnyControlAdapterConfigMetadata, AnyControlAdapterConfigMetadata,
@ -52,6 +53,9 @@ const renderControlAdapterValueV2: MetadataRenderValueFunc<AnyControlAdapterConf
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`; return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
} }
}; };
const renderLayersValue: MetadataRenderValueFunc<Layer[]> = async (value) => {
return `${value.length} ${t('controlLayers.layers', { count: value.length })}`;
};
const parameterSetToast = (parameter: string, description?: string) => { const parameterSetToast = (parameter: string, description?: string) => {
toast({ toast({
@ -171,6 +175,7 @@ const buildHandlers: BuildMetadataHandlers = ({
itemValidator, itemValidator,
renderValue, renderValue,
renderItemValue, renderItemValue,
getIsVisible,
}) => ({ }) => ({
parse: buildParse({ parser, getLabel }), parse: buildParse({ parser, getLabel }),
parseItem: itemParser ? buildParseItem({ itemParser, getLabel }) : undefined, parseItem: itemParser ? buildParseItem({ itemParser, getLabel }) : undefined,
@ -179,6 +184,7 @@ const buildHandlers: BuildMetadataHandlers = ({
getLabel, getLabel,
renderValue: renderValue ?? resolveToString, renderValue: renderValue ?? resolveToString,
renderItemValue: renderItemValue ?? resolveToString, renderItemValue: renderItemValue ?? resolveToString,
getIsVisible,
}); });
export const handlers = { export const handlers = {
@ -380,6 +386,14 @@ export const handlers = {
itemValidator: validators.t2iAdapterV2, itemValidator: validators.t2iAdapterV2,
renderItemValue: renderControlAdapterValueV2, renderItemValue: renderControlAdapterValueV2,
}), }),
layers: buildHandlers({
getLabel: () => t('controlLayers.layers_other'),
parser: parsers.layers,
recaller: recallers.layers,
validator: validators.layers,
renderValue: renderLayersValue,
getIsVisible: (value) => value.length > 0,
}),
} as const; } as const;
export const parseAndRecallPrompts = async (metadata: unknown) => { export const parseAndRecallPrompts = async (metadata: unknown) => {
@ -435,9 +449,22 @@ export const parseAndRecallImageDimensions = async (metadata: unknown) => {
}; };
// These handlers should be omitted when recalling to control layers // These handlers should be omitted when recalling to control layers
const TO_CONTROL_LAYERS_SKIP_KEYS: (keyof typeof handlers)[] = ['controlNets', 'ipAdapters', 't2iAdapters']; const TO_CONTROL_LAYERS_SKIP_KEYS: (keyof typeof handlers)[] = [
'controlNets',
'ipAdapters',
't2iAdapters',
'controlNetsV2',
'ipAdaptersV2',
't2iAdaptersV2',
];
// These handlers should be omitted when recalling to the rest of the app // These handlers should be omitted when recalling to the rest of the app
const NOT_TO_CONTROL_LAYERS_SKIP_KEYS: (keyof typeof handlers)[] = ['controlNetsV2', 'ipAdaptersV2', 't2iAdaptersV2']; const NOT_TO_CONTROL_LAYERS_SKIP_KEYS: (keyof typeof handlers)[] = [
'controlNetsV2',
'ipAdaptersV2',
't2iAdaptersV2',
'initialImage',
'layers',
];
export const parseAndRecallAllMetadata = async ( export const parseAndRecallAllMetadata = async (
metadata: unknown, metadata: unknown,

View File

@ -5,6 +5,8 @@ import {
initialT2IAdapter, initialT2IAdapter,
} from 'features/controlAdapters/util/buildControlAdapter'; } from 'features/controlAdapters/util/buildControlAdapter';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor'; import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import type { Layer } from 'features/controlLayers/store/types';
import { zLayer } from 'features/controlLayers/store/types';
import { import {
CA_PROCESSOR_DATA, CA_PROCESSOR_DATA,
imageDTOToImageWithDims, imageDTOToImageWithDims,
@ -623,6 +625,19 @@ const parseIPAdapterV2: MetadataParseFunc<IPAdapterConfigV2Metadata> = async (me
return ipAdapter; return ipAdapter;
}; };
const parseLayers: MetadataParseFunc<Layer[]> = async (metadata) => {
try {
const layersRaw = await getProperty(metadata, 'layers', isArray);
const parseResults = await Promise.allSettled(layersRaw.map((layerRaw) => zLayer.parseAsync(layerRaw)));
const layers = parseResults
.filter((result): result is PromiseFulfilledResult<Layer> => result.status === 'fulfilled')
.map((result) => result.value);
return layers;
} catch {
return [];
}
};
const parseAllIPAdaptersV2: MetadataParseFunc<IPAdapterConfigV2Metadata[]> = async (metadata) => { const parseAllIPAdaptersV2: MetadataParseFunc<IPAdapterConfigV2Metadata[]> = async (metadata) => {
try { try {
const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray); const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray);
@ -678,4 +693,5 @@ export const parsers = {
t2iAdaptersV2: parseAllT2IAdaptersV2, t2iAdaptersV2: parseAllT2IAdaptersV2,
ipAdapterV2: parseIPAdapterV2, ipAdapterV2: parseIPAdapterV2,
ipAdaptersV2: parseAllIPAdaptersV2, ipAdaptersV2: parseAllIPAdaptersV2,
layers: parseLayers,
} as const; } as const;

View File

@ -6,19 +6,24 @@ import {
t2iAdaptersReset, t2iAdaptersReset,
} from 'features/controlAdapters/store/controlAdaptersSlice'; } from 'features/controlAdapters/store/controlAdaptersSlice';
import { import {
allLayersDeleted,
caLayerAdded, caLayerAdded,
caLayerControlNetsDeleted, caLayerControlNetsDeleted,
caLayerRecalled,
caLayerT2IAdaptersDeleted, caLayerT2IAdaptersDeleted,
heightChanged, heightChanged,
iiLayerAdded, iiLayerAdded,
ipaLayerAdded, ipaLayerAdded,
ipaLayerRecalled,
ipaLayersDeleted, ipaLayersDeleted,
negativePrompt2Changed, negativePrompt2Changed,
negativePromptChanged, negativePromptChanged,
positivePrompt2Changed, positivePrompt2Changed,
positivePromptChanged, positivePromptChanged,
rgLayerRecalled,
widthChanged, widthChanged,
} from 'features/controlLayers/store/controlLayersSlice'; } from 'features/controlLayers/store/controlLayersSlice';
import type { Layer } from 'features/controlLayers/store/types';
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice'; import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
import type { LoRA } from 'features/lora/store/loraSlice'; import type { LoRA } from 'features/lora/store/loraSlice';
import { loraRecalled, lorasReset } from 'features/lora/store/loraSlice'; import { loraRecalled, lorasReset } from 'features/lora/store/loraSlice';
@ -290,6 +295,24 @@ const recallIPAdaptersV2: MetadataRecallFunc<IPAdapterConfigV2Metadata[]> = (ipA
}); });
}; };
const recallLayers: MetadataRecallFunc<Layer[]> = (layers) => {
const { dispatch } = getStore();
dispatch(allLayersDeleted());
for (const l of layers) {
switch (l.type) {
case 'control_adapter_layer':
dispatch(caLayerRecalled(l));
break;
case 'ip_adapter_layer':
dispatch(ipaLayerRecalled(l));
break;
case 'regional_guidance_layer':
dispatch(rgLayerRecalled(l));
break;
}
}
};
export const recallers = { export const recallers = {
positivePrompt: recallPositivePrompt, positivePrompt: recallPositivePrompt,
negativePrompt: recallNegativePrompt, negativePrompt: recallNegativePrompt,
@ -330,4 +353,5 @@ export const recallers = {
t2iAdaptersV2: recallT2IAdaptersV2, t2iAdaptersV2: recallT2IAdaptersV2,
ipAdapterV2: recallIPAdapterV2, ipAdapterV2: recallIPAdapterV2,
ipAdaptersV2: recallIPAdaptersV2, ipAdaptersV2: recallIPAdaptersV2,
layers: recallLayers,
} as const; } as const;

View File

@ -1,4 +1,5 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import type { Layer } from 'features/controlLayers/store/types';
import type { LoRA } from 'features/lora/store/loraSlice'; import type { LoRA } from 'features/lora/store/loraSlice';
import type { import type {
ControlNetConfigMetadata, ControlNetConfigMetadata,
@ -165,6 +166,29 @@ const validateIPAdaptersV2: MetadataValidateFunc<IPAdapterConfigV2Metadata[]> =
return new Promise((resolve) => resolve(validatedIPAdapters)); return new Promise((resolve) => resolve(validatedIPAdapters));
}; };
const validateLayers: MetadataValidateFunc<Layer[]> = (layers) => {
const validatedLayers: Layer[] = [];
for (const l of layers) {
try {
if (l.type === 'control_adapter_layer') {
validateBaseCompatibility(l.controlAdapter.model?.base, 'Layer incompatible with currently-selected model');
}
if (l.type === 'ip_adapter_layer') {
validateBaseCompatibility(l.ipAdapter.model?.base, 'Layer incompatible with currently-selected model');
}
if (l.type === 'regional_guidance_layer') {
for (const ipa of l.ipAdapters) {
validateBaseCompatibility(ipa.model?.base, 'Layer incompatible with currently-selected model');
}
}
validatedLayers.push(l);
} catch {
// This is a no-op - we want to continue validating the rest of the layers, and an empty list is valid.
}
}
return new Promise((resolve) => resolve(validatedLayers));
};
export const validators = { export const validators = {
refinerModel: validateRefinerModel, refinerModel: validateRefinerModel,
vaeModel: validateVAEModel, vaeModel: validateVAEModel,
@ -182,4 +206,5 @@ export const validators = {
t2iAdaptersV2: validateT2IAdaptersV2, t2iAdaptersV2: validateT2IAdaptersV2,
ipAdapterV2: validateIPAdapterV2, ipAdapterV2: validateIPAdapterV2,
ipAdaptersV2: validateIPAdaptersV2, ipAdaptersV2: validateIPAdaptersV2,
layers: validateLayers,
} as const; } as const;