diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.test.ts deleted file mode 100644 index 69e3676641..0000000000 --- a/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.test.ts +++ /dev/null @@ -1,91 +0,0 @@ -import { isModelIdentifier } from 'features/nodes/types/common'; -import { Graph } from 'features/nodes/util/graph/Graph'; -import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; -import { pick } from 'lodash-es'; -import type { AnyModelConfig } from 'services/api/types'; -import { AssertionError } from 'tsafe'; -import { describe, expect, it } from 'vitest'; - -describe('MetadataUtil', () => { - describe('getNode', () => { - it('should return the metadata node if one exists', () => { - const g = new Graph(); - // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing - const metadataNode = g.addNode({ id: MetadataUtil.metadataNodeId, type: 'core_metadata' }); - expect(MetadataUtil.getNode(g)).toEqual(metadataNode); - }); - it('should raise an error if the metadata node does not exist', () => { - const g = new Graph(); - expect(() => MetadataUtil.getNode(g)).toThrowError(AssertionError); - }); - }); - - describe('add', () => { - const g = new Graph(); - it("should add metadata, creating the node if it doesn't exist", () => { - MetadataUtil.add(g, { foo: 'bar' }); - const metadataNode = MetadataUtil.getNode(g); - expect(metadataNode['type']).toBe('core_metadata'); - expect(metadataNode['foo']).toBe('bar'); - }); - it('should update existing metadata keys', () => { - const updatedMetadataNode = MetadataUtil.add(g, { foo: 'bananas', baz: 'qux' }); - expect(updatedMetadataNode['foo']).toBe('bananas'); - expect(updatedMetadataNode['baz']).toBe('qux'); - }); - }); - - describe('remove', () => { - it('should remove a single key', () => { - const g = new Graph(); - MetadataUtil.add(g, { foo: 'bar', baz: 'qux' }); - const updatedMetadataNode = MetadataUtil.remove(g, 'foo'); - expect(updatedMetadataNode['foo']).toBeUndefined(); - expect(updatedMetadataNode['baz']).toBe('qux'); - }); - it('should remove multiple keys', () => { - const g = new Graph(); - MetadataUtil.add(g, { foo: 'bar', baz: 'qux' }); - const updatedMetadataNode = MetadataUtil.remove(g, ['foo', 'baz']); - expect(updatedMetadataNode['foo']).toBeUndefined(); - expect(updatedMetadataNode['baz']).toBeUndefined(); - }); - }); - - describe('setMetadataReceivingNode', () => { - const g = new Graph(); - it('should add an edge from from metadata to the receiving node', () => { - const n = g.addNode({ id: 'my-node', type: 'img_resize' }); - MetadataUtil.add(g, { foo: 'bar' }); - MetadataUtil.setMetadataReceivingNode(g, n); - // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing - expect(g.hasEdge(MetadataUtil.getNode(g), 'metadata', n, 'metadata')).toBe(true); - }); - it('should remove existing metadata edges', () => { - const n2 = g.addNode({ id: 'my-other-node', type: 'img_resize' }); - MetadataUtil.setMetadataReceivingNode(g, n2); - expect(g.getIncomers(n2).length).toBe(1); - // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing - expect(g.hasEdge(MetadataUtil.getNode(g), 'metadata', n2, 'metadata')).toBe(true); - }); - }); - - describe('getModelMetadataField', () => { - it('should return a ModelIdentifierField', () => { - const model: AnyModelConfig = { - key: 'model_key', - type: 'main', - hash: 'model_hash', - base: 'sd-1', - format: 'diffusers', - name: 'my model', - path: '/some/path', - source: 'www.models.com', - source_type: 'url', - }; - const metadataField = MetadataUtil.getModelMetadataField(model); - expect(isModelIdentifier(metadataField)).toBe(true); - expect(pick(model, ['key', 'hash', 'name', 'base', 'type'])).toEqual(metadataField); - }); - }); -}); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.ts b/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.ts deleted file mode 100644 index 38e57a5e65..0000000000 --- a/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.ts +++ /dev/null @@ -1,74 +0,0 @@ -import type { ModelIdentifierField } from 'features/nodes/types/common'; -import { METADATA } from 'features/nodes/util/graph/constants'; -import { isString, unset } from 'lodash-es'; -import type { - AnyInvocation, - AnyInvocationIncMetadata, - AnyModelConfig, - CoreMetadataInvocation, - S, -} from 'services/api/types'; -import { assert } from 'tsafe'; - -import type { Graph } from './Graph'; - -const isCoreMetadata = (node: S['Graph']['nodes'][string]): node is CoreMetadataInvocation => - node.type === 'core_metadata'; - -export class MetadataUtil { - static metadataNodeId = METADATA; - - static getNode(g: Graph): CoreMetadataInvocation { - const node = g.getNode(this.metadataNodeId) as AnyInvocationIncMetadata; - assert(isCoreMetadata(node)); - return node; - } - - static add(g: Graph, metadata: Partial): CoreMetadataInvocation { - try { - const node = g.getNode(this.metadataNodeId) as AnyInvocationIncMetadata; - assert(isCoreMetadata(node)); - Object.assign(node, metadata); - return node; - } catch { - const metadataNode: CoreMetadataInvocation = { - id: this.metadataNodeId, - type: 'core_metadata', - ...metadata, - }; - // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing - return g.addNode(metadataNode); - } - } - - static remove(g: Graph, key: string): CoreMetadataInvocation; - static remove(g: Graph, keys: string[]): CoreMetadataInvocation; - static remove(g: Graph, keyOrKeys: string | string[]): CoreMetadataInvocation { - const metadataNode = this.getNode(g); - if (isString(keyOrKeys)) { - unset(metadataNode, keyOrKeys); - } else { - for (const key of keyOrKeys) { - unset(metadataNode, key); - } - } - return metadataNode; - } - - static setMetadataReceivingNode(g: Graph, node: AnyInvocation): void { - // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing - g.deleteEdgesFrom(this.getNode(g)); - // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing - g.addEdge(this.getNode(g), 'metadata', node, 'metadata'); - } - - static getModelMetadataField({ key, hash, name, base, type }: AnyModelConfig): ModelIdentifierField { - return { - key, - hash, - name, - base, - type, - }; - } -} diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts index 3c7c0c9c66..6d890a29e2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts @@ -31,7 +31,6 @@ import { T2I_ADAPTER_COLLECT, } from 'features/nodes/util/graph/constants'; import type { Graph } from 'features/nodes/util/graph/Graph'; -import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; import { size } from 'lodash-es'; import { getImageDTO, imagesApi } from 'services/api/endpoints/images'; import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types'; @@ -245,7 +244,7 @@ export const addGenerationTabControlLayers = async ( } } - MetadataUtil.add(g, { control_layers: { layers: validLayers, version: state.controlLayers.present._version } }); + g.upsertMetadata({ control_layers: { layers: validLayers, version: state.controlLayers.present._version } }); return validLayers; }; @@ -490,7 +489,7 @@ const addInitialImageLayerToGraph = ( g.addEdge(i2l, 'height', noise, 'height'); } - MetadataUtil.add(g, { generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' }); + g.upsertMetadata({ generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' }); }; const isValidControlAdapter = (ca: ControlNetConfigV2 | T2IAdapterConfigV2, base: BaseModelType): boolean => { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabHRF.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabHRF.ts index 7d1b20d018..b9adf6b8bb 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabHRF.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabHRF.ts @@ -3,7 +3,6 @@ import { deepClone } from 'common/util/deepClone'; import { roundToMultiple } from 'common/util/roundDownToMultiple'; import type { Graph } from 'features/nodes/util/graph/Graph'; import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; -import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; import type { Invocation } from 'services/api/types'; @@ -157,12 +156,12 @@ export const addGenerationTabHRF = ( g.addEdge(vaeSource, 'vae', l2iHrfHR, 'vae'); g.addEdge(denoiseHrf, 'latents', l2iHrfHR, 'latents'); - MetadataUtil.add(g, { + g.upsertMetadata({ hrf_strength: hrfStrength, hrf_enabled: hrfEnabled, hrf_method: hrfMethod, }); - MetadataUtil.setMetadataReceivingNode(g, l2iHrfHR); + g.setMetadataReceivingNode(l2iHrfHR); return l2iHrfHR; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabLoRAs.ts index 6374c72a93..ee812468b8 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabLoRAs.ts @@ -1,7 +1,6 @@ import type { RootState } from 'app/store/store'; import { zModelIdentifierField } from 'features/nodes/types/common'; import type { Graph } from 'features/nodes/util/graph/Graph'; -import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; import { filter, size } from 'lodash-es'; import type { Invocation, S } from 'services/api/types'; @@ -69,5 +68,5 @@ export const addGenerationTabLoRAs = ( g.addEdge(loraSelector, 'lora', loraCollector, 'item'); } - MetadataUtil.add(g, { loras: loraMetadata }); + g.upsertMetadata({ loras: loraMetadata }); }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLLoRAs.ts index 89f1f8f18e..bbd16e8f53 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLLoRAs.ts @@ -1,7 +1,6 @@ import type { RootState } from 'app/store/store'; import { zModelIdentifierField } from 'features/nodes/types/common'; import type { Graph } from 'features/nodes/util/graph/Graph'; -import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; import { filter, size } from 'lodash-es'; import type { Invocation, S } from 'services/api/types'; @@ -71,5 +70,5 @@ export const addGenerationTabSDXLLoRAs = ( g.addEdge(loraSelector, 'lora', loraCollector, 'item'); } - MetadataUtil.add(g, { loras: loraMetadata }); + g.upsertMetadata({ loras: loraMetadata }); }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLRefiner.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLRefiner.ts index 7c207b75bb..0cbb637d03 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLRefiner.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLRefiner.ts @@ -1,7 +1,6 @@ import type { RootState } from 'app/store/store'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import type { Graph } from 'features/nodes/util/graph/Graph'; -import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; import type { Invocation } from 'services/api/types'; import { isRefinerMainModelModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; @@ -92,7 +91,7 @@ export const addGenerationTabSDXLRefiner = async ( g.addEdge(denoise, 'latents', refinerDenoise, 'latents'); g.addEdge(refinerDenoise, 'latents', l2i, 'latents'); - MetadataUtil.add(g, { + g.upsertMetadata({ refiner_model: getModelMetadataField(modelConfig), refiner_positive_aesthetic_score: refinerPositiveAestheticScore, refiner_negative_aesthetic_score: refinerNegativeAestheticScore, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSeamless.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSeamless.ts index 709ba1416c..a3303e6c6f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSeamless.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSeamless.ts @@ -1,6 +1,5 @@ import type { RootState } from 'app/store/store'; import type { Graph } from 'features/nodes/util/graph/Graph'; -import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; import type { Invocation } from 'services/api/types'; import { SEAMLESS } from './constants'; @@ -36,7 +35,7 @@ export const addGenerationTabSeamless = ( seamless_y, }); - MetadataUtil.add(g, { + g.upsertMetadata({ seamless_x: seamless_x || undefined, seamless_y: seamless_y || undefined, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabGraph.ts index a50e722e43..c24353cc40 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabGraph.ts @@ -10,7 +10,6 @@ import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenera import type { GraphType } from 'features/nodes/util/graph/Graph'; import { Graph } from 'features/nodes/util/graph/Graph'; import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils'; -import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; import type { Invocation } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; @@ -128,7 +127,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise