feat(ui): use integrated metadata helper

This commit is contained in:
psychedelicious 2024-05-14 13:45:25 +10:00
parent ee647a05dc
commit 48ccd63dba
10 changed files with 12 additions and 185 deletions

View File

@ -1,91 +0,0 @@
import { isModelIdentifier } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import { pick } from 'lodash-es';
import type { AnyModelConfig } from 'services/api/types';
import { AssertionError } from 'tsafe';
import { describe, expect, it } from 'vitest';
describe('MetadataUtil', () => {
describe('getNode', () => {
it('should return the metadata node if one exists', () => {
const g = new Graph();
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
const metadataNode = g.addNode({ id: MetadataUtil.metadataNodeId, type: 'core_metadata' });
expect(MetadataUtil.getNode(g)).toEqual(metadataNode);
});
it('should raise an error if the metadata node does not exist', () => {
const g = new Graph();
expect(() => MetadataUtil.getNode(g)).toThrowError(AssertionError);
});
});
describe('add', () => {
const g = new Graph();
it("should add metadata, creating the node if it doesn't exist", () => {
MetadataUtil.add(g, { foo: 'bar' });
const metadataNode = MetadataUtil.getNode(g);
expect(metadataNode['type']).toBe('core_metadata');
expect(metadataNode['foo']).toBe('bar');
});
it('should update existing metadata keys', () => {
const updatedMetadataNode = MetadataUtil.add(g, { foo: 'bananas', baz: 'qux' });
expect(updatedMetadataNode['foo']).toBe('bananas');
expect(updatedMetadataNode['baz']).toBe('qux');
});
});
describe('remove', () => {
it('should remove a single key', () => {
const g = new Graph();
MetadataUtil.add(g, { foo: 'bar', baz: 'qux' });
const updatedMetadataNode = MetadataUtil.remove(g, 'foo');
expect(updatedMetadataNode['foo']).toBeUndefined();
expect(updatedMetadataNode['baz']).toBe('qux');
});
it('should remove multiple keys', () => {
const g = new Graph();
MetadataUtil.add(g, { foo: 'bar', baz: 'qux' });
const updatedMetadataNode = MetadataUtil.remove(g, ['foo', 'baz']);
expect(updatedMetadataNode['foo']).toBeUndefined();
expect(updatedMetadataNode['baz']).toBeUndefined();
});
});
describe('setMetadataReceivingNode', () => {
const g = new Graph();
it('should add an edge from from metadata to the receiving node', () => {
const n = g.addNode({ id: 'my-node', type: 'img_resize' });
MetadataUtil.add(g, { foo: 'bar' });
MetadataUtil.setMetadataReceivingNode(g, n);
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
expect(g.hasEdge(MetadataUtil.getNode(g), 'metadata', n, 'metadata')).toBe(true);
});
it('should remove existing metadata edges', () => {
const n2 = g.addNode({ id: 'my-other-node', type: 'img_resize' });
MetadataUtil.setMetadataReceivingNode(g, n2);
expect(g.getIncomers(n2).length).toBe(1);
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
expect(g.hasEdge(MetadataUtil.getNode(g), 'metadata', n2, 'metadata')).toBe(true);
});
});
describe('getModelMetadataField', () => {
it('should return a ModelIdentifierField', () => {
const model: AnyModelConfig = {
key: 'model_key',
type: 'main',
hash: 'model_hash',
base: 'sd-1',
format: 'diffusers',
name: 'my model',
path: '/some/path',
source: 'www.models.com',
source_type: 'url',
};
const metadataField = MetadataUtil.getModelMetadataField(model);
expect(isModelIdentifier(metadataField)).toBe(true);
expect(pick(model, ['key', 'hash', 'name', 'base', 'type'])).toEqual(metadataField);
});
});
});

View File

@ -1,74 +0,0 @@
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { METADATA } from 'features/nodes/util/graph/constants';
import { isString, unset } from 'lodash-es';
import type {
AnyInvocation,
AnyInvocationIncMetadata,
AnyModelConfig,
CoreMetadataInvocation,
S,
} from 'services/api/types';
import { assert } from 'tsafe';
import type { Graph } from './Graph';
const isCoreMetadata = (node: S['Graph']['nodes'][string]): node is CoreMetadataInvocation =>
node.type === 'core_metadata';
export class MetadataUtil {
static metadataNodeId = METADATA;
static getNode(g: Graph): CoreMetadataInvocation {
const node = g.getNode(this.metadataNodeId) as AnyInvocationIncMetadata;
assert(isCoreMetadata(node));
return node;
}
static add(g: Graph, metadata: Partial<CoreMetadataInvocation>): CoreMetadataInvocation {
try {
const node = g.getNode(this.metadataNodeId) as AnyInvocationIncMetadata;
assert(isCoreMetadata(node));
Object.assign(node, metadata);
return node;
} catch {
const metadataNode: CoreMetadataInvocation = {
id: this.metadataNodeId,
type: 'core_metadata',
...metadata,
};
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
return g.addNode(metadataNode);
}
}
static remove(g: Graph, key: string): CoreMetadataInvocation;
static remove(g: Graph, keys: string[]): CoreMetadataInvocation;
static remove(g: Graph, keyOrKeys: string | string[]): CoreMetadataInvocation {
const metadataNode = this.getNode(g);
if (isString(keyOrKeys)) {
unset(metadataNode, keyOrKeys);
} else {
for (const key of keyOrKeys) {
unset(metadataNode, key);
}
}
return metadataNode;
}
static setMetadataReceivingNode(g: Graph, node: AnyInvocation): void {
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
g.deleteEdgesFrom(this.getNode(g));
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
g.addEdge(this.getNode(g), 'metadata', node, 'metadata');
}
static getModelMetadataField({ key, hash, name, base, type }: AnyModelConfig): ModelIdentifierField {
return {
key,
hash,
name,
base,
type,
};
}
}

View File

@ -31,7 +31,6 @@ import {
T2I_ADAPTER_COLLECT,
} from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import { size } from 'lodash-es';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
@ -245,7 +244,7 @@ export const addGenerationTabControlLayers = async (
}
}
MetadataUtil.add(g, { control_layers: { layers: validLayers, version: state.controlLayers.present._version } });
g.upsertMetadata({ control_layers: { layers: validLayers, version: state.controlLayers.present._version } });
return validLayers;
};
@ -490,7 +489,7 @@ const addInitialImageLayerToGraph = (
g.addEdge(i2l, 'height', noise, 'height');
}
MetadataUtil.add(g, { generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' });
g.upsertMetadata({ generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' });
};
const isValidControlAdapter = (ca: ControlNetConfigV2 | T2IAdapterConfigV2, base: BaseModelType): boolean => {

View File

@ -3,7 +3,6 @@ import { deepClone } from 'common/util/deepClone';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import type { Invocation } from 'services/api/types';
@ -157,12 +156,12 @@ export const addGenerationTabHRF = (
g.addEdge(vaeSource, 'vae', l2iHrfHR, 'vae');
g.addEdge(denoiseHrf, 'latents', l2iHrfHR, 'latents');
MetadataUtil.add(g, {
g.upsertMetadata({
hrf_strength: hrfStrength,
hrf_enabled: hrfEnabled,
hrf_method: hrfMethod,
});
MetadataUtil.setMetadataReceivingNode(g, l2iHrfHR);
g.setMetadataReceivingNode(l2iHrfHR);
return l2iHrfHR;
};

View File

@ -1,7 +1,6 @@
import type { RootState } from 'app/store/store';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import { filter, size } from 'lodash-es';
import type { Invocation, S } from 'services/api/types';
@ -69,5 +68,5 @@ export const addGenerationTabLoRAs = (
g.addEdge(loraSelector, 'lora', loraCollector, 'item');
}
MetadataUtil.add(g, { loras: loraMetadata });
g.upsertMetadata({ loras: loraMetadata });
};

View File

@ -1,7 +1,6 @@
import type { RootState } from 'app/store/store';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import { filter, size } from 'lodash-es';
import type { Invocation, S } from 'services/api/types';
@ -71,5 +70,5 @@ export const addGenerationTabSDXLLoRAs = (
g.addEdge(loraSelector, 'lora', loraCollector, 'item');
}
MetadataUtil.add(g, { loras: loraMetadata });
g.upsertMetadata({ loras: loraMetadata });
};

View File

@ -1,7 +1,6 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
@ -92,7 +91,7 @@ export const addGenerationTabSDXLRefiner = async (
g.addEdge(denoise, 'latents', refinerDenoise, 'latents');
g.addEdge(refinerDenoise, 'latents', l2i, 'latents');
MetadataUtil.add(g, {
g.upsertMetadata({
refiner_model: getModelMetadataField(modelConfig),
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,

View File

@ -1,6 +1,5 @@
import type { RootState } from 'app/store/store';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';
import { SEAMLESS } from './constants';
@ -36,7 +35,7 @@ export const addGenerationTabSeamless = (
seamless_y,
});
MetadataUtil.add(g, {
g.upsertMetadata({
seamless_x: seamless_x || undefined,
seamless_y: seamless_y || undefined,
});

View File

@ -10,7 +10,6 @@ import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenera
import type { GraphType } from 'features/nodes/util/graph/Graph';
import { Graph } from 'features/nodes/util/graph/Graph';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
@ -128,7 +127,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
MetadataUtil.add(g, {
g.upsertMetadata({
generation_mode: 'txt2img',
cfg_scale,
cfg_rescale_multiplier,
@ -182,6 +181,6 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
imageOutput = addGenerationTabWatermarker(g, imageOutput);
}
MetadataUtil.setMetadataReceivingNode(g, imageOutput);
g.setMetadataReceivingNode(imageOutput);
return g.getGraph();
};

View File

@ -7,7 +7,6 @@ import { addGenerationTabSDXLRefiner } from 'features/nodes/util/graph/addGenera
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenerationTabWatermarker';
import { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation, NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
@ -119,7 +118,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
MetadataUtil.add(g, {
g.upsertMetadata({
generation_mode: 'txt2img',
cfg_scale,
cfg_rescale_multiplier,
@ -173,6 +172,6 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
imageOutput = addGenerationTabWatermarker(g, imageOutput);
}
MetadataUtil.setMetadataReceivingNode(g, imageOutput);
g.setMetadataReceivingNode(imageOutput);
return g.getGraph();
};