feat(ui): move metadata util to graph class

No good reason to have it be separate. A bit cleaner this way.
This commit is contained in:
psychedelicious 2024-05-14 13:41:17 +10:00
parent 154b52ca4d
commit ee647a05dc
2 changed files with 173 additions and 3 deletions

View File

@ -1,5 +1,5 @@
import { Graph } from 'features/nodes/util/graph/Graph';
import type { Invocation } from 'services/api/types';
import type { AnyInvocation, Invocation } from 'services/api/types';
import { assert, AssertionError, is } from 'tsafe';
import { validate } from 'uuid';
import { describe, expect, it } from 'vitest';
@ -414,4 +414,106 @@ describe('Graph', () => {
expect(g.getEdgesTo(n3)).toEqual([e2]);
});
});
describe('metadata utils', () => {
describe('_getMetadataNode', () => {
it("should get the metadata node, creating it if it doesn't exist", () => {
const g = new Graph();
const metadata = g._getMetadataNode();
expect(metadata.id).toBe('core_metadata');
expect(metadata.type).toBe('core_metadata');
g.upsertMetadata({ test: 'test' });
const metadata2 = g._getMetadataNode();
expect(metadata2).toHaveProperty('test');
});
});
describe('upsertMetadata', () => {
it('should add metadata to the metadata node', () => {
const g = new Graph();
g.upsertMetadata({ test: 'test' });
const metadata = g._getMetadataNode();
expect(metadata).toHaveProperty('test');
});
it('should update metadata on the metadata node', () => {
const g = new Graph();
g.upsertMetadata({ test: 'test' });
g.upsertMetadata({ test: 'test2' });
const metadata = g._getMetadataNode();
expect(metadata.test).toBe('test2');
});
});
describe('removeMetadata', () => {
it('should remove metadata from the metadata node', () => {
const g = new Graph();
g.upsertMetadata({ test: 'test', test2: 'test2' });
g.removeMetadata(['test']);
const metadata = g._getMetadataNode();
expect(metadata).not.toHaveProperty('test');
});
it('should remove multiple metadata from the metadata node', () => {
const g = new Graph();
g.upsertMetadata({ test: 'test', test2: 'test2' });
g.removeMetadata(['test', 'test2']);
const metadata = g._getMetadataNode();
expect(metadata).not.toHaveProperty('test');
expect(metadata).not.toHaveProperty('test2');
});
});
describe('setMetadataReceivingNode', () => {
it('should set the metadata receiving node', () => {
const g = new Graph();
const n1 = g.addNode({
id: 'n1',
type: 'img_resize',
});
g.upsertMetadata({ test: 'test' });
g.setMetadataReceivingNode(n1);
const metadata = g._getMetadataNode();
expect(g.getEdgesFrom(metadata as unknown as AnyInvocation).length).toBe(1);
expect(g.getEdgesTo(n1).length).toBe(1);
});
});
describe('getModelMetadataField', () => {
it('should return a ModelIdentifierField', () => {
const field = Graph.getModelMetadataField({
key: 'b00ee8df-523d-40d2-9578-597283b07cb2',
hash: 'random:9adf270422f525715297afa1649c4ff007a55f09937f57ca628278305624d194',
path: 'sdxl/main/stable-diffusion-xl-1.0-inpainting-0.1',
name: 'stable-diffusion-xl-1.0-inpainting-0.1',
base: 'sdxl',
description: 'sdxl main model stable-diffusion-xl-1.0-inpainting-0.1',
source: '/home/bat/invokeai-4.0.0/models/sdxl/main/stable-diffusion-xl-1.0-inpainting-0.1',
source_type: 'path',
source_api_response: null,
cover_image: null,
type: 'main',
trigger_phrases: null,
default_settings: {
vae: null,
vae_precision: null,
scheduler: null,
steps: null,
cfg_scale: null,
cfg_rescale_multiplier: null,
width: 1024,
height: 1024,
},
variant: 'inpaint',
format: 'diffusers',
repo_variant: 'fp16',
});
expect(field).toEqual({
key: 'b00ee8df-523d-40d2-9578-597283b07cb2',
hash: 'random:9adf270422f525715297afa1649c4ff007a55f09937f57ca628278305624d194',
name: 'stable-diffusion-xl-1.0-inpainting-0.1',
base: 'sdxl',
type: 'main',
});
});
});
});
});

View File

@ -1,8 +1,13 @@
import { forEach, groupBy, isEqual, values } from 'lodash-es';
import { type ModelIdentifierField, zModelIdentifierField } from 'features/nodes/types/common';
import { METADATA } from 'features/nodes/util/graph/constants';
import { forEach, groupBy, isEqual, unset, values } from 'lodash-es';
import type {
AnyInvocation,
AnyInvocationIncMetadata,
AnyInvocationInputField,
AnyInvocationOutputField,
AnyModelConfig,
CoreMetadataInvocation,
InputFields,
Invocation,
InvocationType,
@ -332,8 +337,71 @@ export class Graph {
}
//#endregion
//#region Util
//#region Metadata
/**
* INTERNAL: Get the metadata node. If it does not exist, it is created.
* @returns The metadata node.
*/
_getMetadataNode(): CoreMetadataInvocation {
try {
const node = this.getNode(METADATA) as AnyInvocationIncMetadata;
assert(node.type === 'core_metadata');
return node;
} catch {
const node: CoreMetadataInvocation = { id: METADATA, type: 'core_metadata' };
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
return this.addNode(node);
}
}
/**
* Add metadata to the graph. If the metadata node does not exist, it is created. If the specific metadata key exists,
* it is overwritten.
* @param metadata The metadata to add.
* @returns The metadata node.
*/
upsertMetadata(metadata: Partial<CoreMetadataInvocation>): CoreMetadataInvocation {
const node = this._getMetadataNode();
Object.assign(node, metadata);
return node;
}
/**
* Remove metadata from the graph.
* @param keys The keys of the metadata to remove
* @returns The metadata node
*/
removeMetadata(keys: string[]): CoreMetadataInvocation {
const metadataNode = this._getMetadataNode();
for (const k of keys) {
unset(metadataNode, k);
}
return metadataNode;
}
/**
* Set the node that should receive metadata. All other edges from the metadata node are deleted.
* @param node The node to set as the receiving node
*/
setMetadataReceivingNode(node: AnyInvocation): void {
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
this.deleteEdgesFrom(this._getMetadataNode());
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
this.addEdge(this._getMetadataNode(), 'metadata', node, 'metadata');
}
/**
* Given a model config, return the model metadata field.
* @param modelConfig The model config entity
* @returns
*/
static getModelMetadataField(modelConfig: AnyModelConfig): ModelIdentifierField {
return zModelIdentifierField.parse(modelConfig);
}
//#endregion
//#region Util
static getNodeNotFoundMsg(id: string): string {
return `Node ${id} not found`;
}