mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): use integrated metadata helper
This commit is contained in:
parent
ee647a05dc
commit
48ccd63dba
@ -1,91 +0,0 @@
|
|||||||
import { isModelIdentifier } from 'features/nodes/types/common';
|
|
||||||
import { Graph } from 'features/nodes/util/graph/Graph';
|
|
||||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
|
||||||
import { pick } from 'lodash-es';
|
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
|
||||||
import { AssertionError } from 'tsafe';
|
|
||||||
import { describe, expect, it } from 'vitest';
|
|
||||||
|
|
||||||
describe('MetadataUtil', () => {
|
|
||||||
describe('getNode', () => {
|
|
||||||
it('should return the metadata node if one exists', () => {
|
|
||||||
const g = new Graph();
|
|
||||||
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
|
|
||||||
const metadataNode = g.addNode({ id: MetadataUtil.metadataNodeId, type: 'core_metadata' });
|
|
||||||
expect(MetadataUtil.getNode(g)).toEqual(metadataNode);
|
|
||||||
});
|
|
||||||
it('should raise an error if the metadata node does not exist', () => {
|
|
||||||
const g = new Graph();
|
|
||||||
expect(() => MetadataUtil.getNode(g)).toThrowError(AssertionError);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('add', () => {
|
|
||||||
const g = new Graph();
|
|
||||||
it("should add metadata, creating the node if it doesn't exist", () => {
|
|
||||||
MetadataUtil.add(g, { foo: 'bar' });
|
|
||||||
const metadataNode = MetadataUtil.getNode(g);
|
|
||||||
expect(metadataNode['type']).toBe('core_metadata');
|
|
||||||
expect(metadataNode['foo']).toBe('bar');
|
|
||||||
});
|
|
||||||
it('should update existing metadata keys', () => {
|
|
||||||
const updatedMetadataNode = MetadataUtil.add(g, { foo: 'bananas', baz: 'qux' });
|
|
||||||
expect(updatedMetadataNode['foo']).toBe('bananas');
|
|
||||||
expect(updatedMetadataNode['baz']).toBe('qux');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('remove', () => {
|
|
||||||
it('should remove a single key', () => {
|
|
||||||
const g = new Graph();
|
|
||||||
MetadataUtil.add(g, { foo: 'bar', baz: 'qux' });
|
|
||||||
const updatedMetadataNode = MetadataUtil.remove(g, 'foo');
|
|
||||||
expect(updatedMetadataNode['foo']).toBeUndefined();
|
|
||||||
expect(updatedMetadataNode['baz']).toBe('qux');
|
|
||||||
});
|
|
||||||
it('should remove multiple keys', () => {
|
|
||||||
const g = new Graph();
|
|
||||||
MetadataUtil.add(g, { foo: 'bar', baz: 'qux' });
|
|
||||||
const updatedMetadataNode = MetadataUtil.remove(g, ['foo', 'baz']);
|
|
||||||
expect(updatedMetadataNode['foo']).toBeUndefined();
|
|
||||||
expect(updatedMetadataNode['baz']).toBeUndefined();
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('setMetadataReceivingNode', () => {
|
|
||||||
const g = new Graph();
|
|
||||||
it('should add an edge from from metadata to the receiving node', () => {
|
|
||||||
const n = g.addNode({ id: 'my-node', type: 'img_resize' });
|
|
||||||
MetadataUtil.add(g, { foo: 'bar' });
|
|
||||||
MetadataUtil.setMetadataReceivingNode(g, n);
|
|
||||||
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
|
|
||||||
expect(g.hasEdge(MetadataUtil.getNode(g), 'metadata', n, 'metadata')).toBe(true);
|
|
||||||
});
|
|
||||||
it('should remove existing metadata edges', () => {
|
|
||||||
const n2 = g.addNode({ id: 'my-other-node', type: 'img_resize' });
|
|
||||||
MetadataUtil.setMetadataReceivingNode(g, n2);
|
|
||||||
expect(g.getIncomers(n2).length).toBe(1);
|
|
||||||
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
|
|
||||||
expect(g.hasEdge(MetadataUtil.getNode(g), 'metadata', n2, 'metadata')).toBe(true);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('getModelMetadataField', () => {
|
|
||||||
it('should return a ModelIdentifierField', () => {
|
|
||||||
const model: AnyModelConfig = {
|
|
||||||
key: 'model_key',
|
|
||||||
type: 'main',
|
|
||||||
hash: 'model_hash',
|
|
||||||
base: 'sd-1',
|
|
||||||
format: 'diffusers',
|
|
||||||
name: 'my model',
|
|
||||||
path: '/some/path',
|
|
||||||
source: 'www.models.com',
|
|
||||||
source_type: 'url',
|
|
||||||
};
|
|
||||||
const metadataField = MetadataUtil.getModelMetadataField(model);
|
|
||||||
expect(isModelIdentifier(metadataField)).toBe(true);
|
|
||||||
expect(pick(model, ['key', 'hash', 'name', 'base', 'type'])).toEqual(metadataField);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
@ -1,74 +0,0 @@
|
|||||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
|
||||||
import { METADATA } from 'features/nodes/util/graph/constants';
|
|
||||||
import { isString, unset } from 'lodash-es';
|
|
||||||
import type {
|
|
||||||
AnyInvocation,
|
|
||||||
AnyInvocationIncMetadata,
|
|
||||||
AnyModelConfig,
|
|
||||||
CoreMetadataInvocation,
|
|
||||||
S,
|
|
||||||
} from 'services/api/types';
|
|
||||||
import { assert } from 'tsafe';
|
|
||||||
|
|
||||||
import type { Graph } from './Graph';
|
|
||||||
|
|
||||||
const isCoreMetadata = (node: S['Graph']['nodes'][string]): node is CoreMetadataInvocation =>
|
|
||||||
node.type === 'core_metadata';
|
|
||||||
|
|
||||||
export class MetadataUtil {
|
|
||||||
static metadataNodeId = METADATA;
|
|
||||||
|
|
||||||
static getNode(g: Graph): CoreMetadataInvocation {
|
|
||||||
const node = g.getNode(this.metadataNodeId) as AnyInvocationIncMetadata;
|
|
||||||
assert(isCoreMetadata(node));
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
static add(g: Graph, metadata: Partial<CoreMetadataInvocation>): CoreMetadataInvocation {
|
|
||||||
try {
|
|
||||||
const node = g.getNode(this.metadataNodeId) as AnyInvocationIncMetadata;
|
|
||||||
assert(isCoreMetadata(node));
|
|
||||||
Object.assign(node, metadata);
|
|
||||||
return node;
|
|
||||||
} catch {
|
|
||||||
const metadataNode: CoreMetadataInvocation = {
|
|
||||||
id: this.metadataNodeId,
|
|
||||||
type: 'core_metadata',
|
|
||||||
...metadata,
|
|
||||||
};
|
|
||||||
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
|
|
||||||
return g.addNode(metadataNode);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static remove(g: Graph, key: string): CoreMetadataInvocation;
|
|
||||||
static remove(g: Graph, keys: string[]): CoreMetadataInvocation;
|
|
||||||
static remove(g: Graph, keyOrKeys: string | string[]): CoreMetadataInvocation {
|
|
||||||
const metadataNode = this.getNode(g);
|
|
||||||
if (isString(keyOrKeys)) {
|
|
||||||
unset(metadataNode, keyOrKeys);
|
|
||||||
} else {
|
|
||||||
for (const key of keyOrKeys) {
|
|
||||||
unset(metadataNode, key);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return metadataNode;
|
|
||||||
}
|
|
||||||
|
|
||||||
static setMetadataReceivingNode(g: Graph, node: AnyInvocation): void {
|
|
||||||
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
|
|
||||||
g.deleteEdgesFrom(this.getNode(g));
|
|
||||||
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
|
|
||||||
g.addEdge(this.getNode(g), 'metadata', node, 'metadata');
|
|
||||||
}
|
|
||||||
|
|
||||||
static getModelMetadataField({ key, hash, name, base, type }: AnyModelConfig): ModelIdentifierField {
|
|
||||||
return {
|
|
||||||
key,
|
|
||||||
hash,
|
|
||||||
name,
|
|
||||||
base,
|
|
||||||
type,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
@ -31,7 +31,6 @@ import {
|
|||||||
T2I_ADAPTER_COLLECT,
|
T2I_ADAPTER_COLLECT,
|
||||||
} from 'features/nodes/util/graph/constants';
|
} from 'features/nodes/util/graph/constants';
|
||||||
import type { Graph } from 'features/nodes/util/graph/Graph';
|
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
|
||||||
import { size } from 'lodash-es';
|
import { size } from 'lodash-es';
|
||||||
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||||
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
|
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
|
||||||
@ -245,7 +244,7 @@ export const addGenerationTabControlLayers = async (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MetadataUtil.add(g, { control_layers: { layers: validLayers, version: state.controlLayers.present._version } });
|
g.upsertMetadata({ control_layers: { layers: validLayers, version: state.controlLayers.present._version } });
|
||||||
return validLayers;
|
return validLayers;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -490,7 +489,7 @@ const addInitialImageLayerToGraph = (
|
|||||||
g.addEdge(i2l, 'height', noise, 'height');
|
g.addEdge(i2l, 'height', noise, 'height');
|
||||||
}
|
}
|
||||||
|
|
||||||
MetadataUtil.add(g, { generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' });
|
g.upsertMetadata({ generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' });
|
||||||
};
|
};
|
||||||
|
|
||||||
const isValidControlAdapter = (ca: ControlNetConfigV2 | T2IAdapterConfigV2, base: BaseModelType): boolean => {
|
const isValidControlAdapter = (ca: ControlNetConfigV2 | T2IAdapterConfigV2, base: BaseModelType): boolean => {
|
||||||
|
@ -3,7 +3,6 @@ import { deepClone } from 'common/util/deepClone';
|
|||||||
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||||
import type { Graph } from 'features/nodes/util/graph/Graph';
|
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
|
||||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||||
import type { Invocation } from 'services/api/types';
|
import type { Invocation } from 'services/api/types';
|
||||||
|
|
||||||
@ -157,12 +156,12 @@ export const addGenerationTabHRF = (
|
|||||||
g.addEdge(vaeSource, 'vae', l2iHrfHR, 'vae');
|
g.addEdge(vaeSource, 'vae', l2iHrfHR, 'vae');
|
||||||
g.addEdge(denoiseHrf, 'latents', l2iHrfHR, 'latents');
|
g.addEdge(denoiseHrf, 'latents', l2iHrfHR, 'latents');
|
||||||
|
|
||||||
MetadataUtil.add(g, {
|
g.upsertMetadata({
|
||||||
hrf_strength: hrfStrength,
|
hrf_strength: hrfStrength,
|
||||||
hrf_enabled: hrfEnabled,
|
hrf_enabled: hrfEnabled,
|
||||||
hrf_method: hrfMethod,
|
hrf_method: hrfMethod,
|
||||||
});
|
});
|
||||||
MetadataUtil.setMetadataReceivingNode(g, l2iHrfHR);
|
g.setMetadataReceivingNode(l2iHrfHR);
|
||||||
|
|
||||||
return l2iHrfHR;
|
return l2iHrfHR;
|
||||||
};
|
};
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import type { Graph } from 'features/nodes/util/graph/Graph';
|
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
|
||||||
import { filter, size } from 'lodash-es';
|
import { filter, size } from 'lodash-es';
|
||||||
import type { Invocation, S } from 'services/api/types';
|
import type { Invocation, S } from 'services/api/types';
|
||||||
|
|
||||||
@ -69,5 +68,5 @@ export const addGenerationTabLoRAs = (
|
|||||||
g.addEdge(loraSelector, 'lora', loraCollector, 'item');
|
g.addEdge(loraSelector, 'lora', loraCollector, 'item');
|
||||||
}
|
}
|
||||||
|
|
||||||
MetadataUtil.add(g, { loras: loraMetadata });
|
g.upsertMetadata({ loras: loraMetadata });
|
||||||
};
|
};
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import type { Graph } from 'features/nodes/util/graph/Graph';
|
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
|
||||||
import { filter, size } from 'lodash-es';
|
import { filter, size } from 'lodash-es';
|
||||||
import type { Invocation, S } from 'services/api/types';
|
import type { Invocation, S } from 'services/api/types';
|
||||||
|
|
||||||
@ -71,5 +70,5 @@ export const addGenerationTabSDXLLoRAs = (
|
|||||||
g.addEdge(loraSelector, 'lora', loraCollector, 'item');
|
g.addEdge(loraSelector, 'lora', loraCollector, 'item');
|
||||||
}
|
}
|
||||||
|
|
||||||
MetadataUtil.add(g, { loras: loraMetadata });
|
g.upsertMetadata({ loras: loraMetadata });
|
||||||
};
|
};
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import type { Graph } from 'features/nodes/util/graph/Graph';
|
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
|
||||||
import type { Invocation } from 'services/api/types';
|
import type { Invocation } from 'services/api/types';
|
||||||
import { isRefinerMainModelModelConfig } from 'services/api/types';
|
import { isRefinerMainModelModelConfig } from 'services/api/types';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
@ -92,7 +91,7 @@ export const addGenerationTabSDXLRefiner = async (
|
|||||||
g.addEdge(denoise, 'latents', refinerDenoise, 'latents');
|
g.addEdge(denoise, 'latents', refinerDenoise, 'latents');
|
||||||
g.addEdge(refinerDenoise, 'latents', l2i, 'latents');
|
g.addEdge(refinerDenoise, 'latents', l2i, 'latents');
|
||||||
|
|
||||||
MetadataUtil.add(g, {
|
g.upsertMetadata({
|
||||||
refiner_model: getModelMetadataField(modelConfig),
|
refiner_model: getModelMetadataField(modelConfig),
|
||||||
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
|
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
|
||||||
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
|
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import type { Graph } from 'features/nodes/util/graph/Graph';
|
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
|
||||||
import type { Invocation } from 'services/api/types';
|
import type { Invocation } from 'services/api/types';
|
||||||
|
|
||||||
import { SEAMLESS } from './constants';
|
import { SEAMLESS } from './constants';
|
||||||
@ -36,7 +35,7 @@ export const addGenerationTabSeamless = (
|
|||||||
seamless_y,
|
seamless_y,
|
||||||
});
|
});
|
||||||
|
|
||||||
MetadataUtil.add(g, {
|
g.upsertMetadata({
|
||||||
seamless_x: seamless_x || undefined,
|
seamless_x: seamless_x || undefined,
|
||||||
seamless_y: seamless_y || undefined,
|
seamless_y: seamless_y || undefined,
|
||||||
});
|
});
|
||||||
|
@ -10,7 +10,6 @@ import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenera
|
|||||||
import type { GraphType } from 'features/nodes/util/graph/Graph';
|
import type { GraphType } from 'features/nodes/util/graph/Graph';
|
||||||
import { Graph } from 'features/nodes/util/graph/Graph';
|
import { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
|
||||||
import type { Invocation } from 'services/api/types';
|
import type { Invocation } from 'services/api/types';
|
||||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
@ -128,7 +127,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
|
|||||||
|
|
||||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
|
|
||||||
MetadataUtil.add(g, {
|
g.upsertMetadata({
|
||||||
generation_mode: 'txt2img',
|
generation_mode: 'txt2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
cfg_rescale_multiplier,
|
cfg_rescale_multiplier,
|
||||||
@ -182,6 +181,6 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
|
|||||||
imageOutput = addGenerationTabWatermarker(g, imageOutput);
|
imageOutput = addGenerationTabWatermarker(g, imageOutput);
|
||||||
}
|
}
|
||||||
|
|
||||||
MetadataUtil.setMetadataReceivingNode(g, imageOutput);
|
g.setMetadataReceivingNode(imageOutput);
|
||||||
return g.getGraph();
|
return g.getGraph();
|
||||||
};
|
};
|
||||||
|
@ -7,7 +7,6 @@ import { addGenerationTabSDXLRefiner } from 'features/nodes/util/graph/addGenera
|
|||||||
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
|
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
|
||||||
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenerationTabWatermarker';
|
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenerationTabWatermarker';
|
||||||
import { Graph } from 'features/nodes/util/graph/Graph';
|
import { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
|
||||||
import type { Invocation, NonNullableGraph } from 'services/api/types';
|
import type { Invocation, NonNullableGraph } from 'services/api/types';
|
||||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
@ -119,7 +118,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
|
|
||||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
|
|
||||||
MetadataUtil.add(g, {
|
g.upsertMetadata({
|
||||||
generation_mode: 'txt2img',
|
generation_mode: 'txt2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
cfg_rescale_multiplier,
|
cfg_rescale_multiplier,
|
||||||
@ -173,6 +172,6 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
imageOutput = addGenerationTabWatermarker(g, imageOutput);
|
imageOutput = addGenerationTabWatermarker(g, imageOutput);
|
||||||
}
|
}
|
||||||
|
|
||||||
MetadataUtil.setMetadataReceivingNode(g, imageOutput);
|
g.setMetadataReceivingNode(imageOutput);
|
||||||
return g.getGraph();
|
return g.getGraph();
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user