diff --git a/docs/installation/010_INSTALL_AUTOMATED.md b/docs/installation/010_INSTALL_AUTOMATED.md index 83b4415394..c710ed17b1 100644 --- a/docs/installation/010_INSTALL_AUTOMATED.md +++ b/docs/installation/010_INSTALL_AUTOMATED.md @@ -89,7 +89,7 @@ experimental versions later. sudo apt update sudo apt install -y software-properties-common sudo add-apt-repository -y ppa:deadsnakes/ppa - sudo apt install python3.10 python3-pip python3.10-venv + sudo apt install -y python3.10 python3-pip python3.10-venv sudo update-alternatives --install /usr/local/bin/python python /usr/bin/python3.10 3 ``` diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py new file mode 100644 index 0000000000..1fb7832031 --- /dev/null +++ b/invokeai/app/invocations/compel.py @@ -0,0 +1,245 @@ +from typing import Literal, Optional, Union +from pydantic import BaseModel, Field + +from invokeai.app.invocations.util.choose_model import choose_model +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig + +from ...backend.util.devices import choose_torch_device, torch_dtype +from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent +from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager + +from compel import Compel +from compel.prompt_parser import ( + Blend, + CrossAttentionControlSubstitute, + FlattenedPrompt, + Fragment, +) + +from invokeai.backend.globals import Globals + + +class ConditioningField(BaseModel): + conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data") + class Config: + schema_extra = {"required": ["conditioning_name"]} + + +class CompelOutput(BaseInvocationOutput): + """Compel parser output""" + + #fmt: off + type: Literal["compel_output"] = "compel_output" + + conditioning: ConditioningField = Field(default=None, description="Conditioning") + #fmt: on + + +class CompelInvocation(BaseInvocation): + """Parse prompt using compel package to conditioning.""" + + type: Literal["compel"] = "compel" + + prompt: str = Field(default="", description="Prompt") + model: str = Field(default="", description="Model to use") + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "title": "Prompt (Compel)", + "tags": ["prompt", "compel"], + "type_hints": { + "model": "model" + } + }, + } + + def invoke(self, context: InvocationContext) -> CompelOutput: + + # TODO: load without model + model = choose_model(context.services.model_manager, self.model) + pipeline = model["model"] + tokenizer = pipeline.tokenizer + text_encoder = pipeline.text_encoder + + # TODO: global? input? + #use_full_precision = precision == "float32" or precision == "autocast" + #use_full_precision = False + + # TODO: redo TI when separate model loding implemented + #textual_inversion_manager = TextualInversionManager( + # tokenizer=tokenizer, + # text_encoder=text_encoder, + # full_precision=use_full_precision, + #) + + def load_huggingface_concepts(concepts: list[str]): + pipeline.textual_inversion_manager.load_huggingface_concepts(concepts) + + # apply the concepts library to the prompt + prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers( + self.prompt, + lambda concepts: load_huggingface_concepts(concepts), + pipeline.textual_inversion_manager.get_all_trigger_strings(), + ) + + # lazy-load any deferred textual inversions. + # this might take a couple of seconds the first time a textual inversion is used. + pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms( + prompt_str + ) + + compel = Compel( + tokenizer=tokenizer, + text_encoder=text_encoder, + textual_inversion_manager=pipeline.textual_inversion_manager, + dtype_for_device_getter=torch_dtype, + truncate_long_prompts=True, # TODO: + ) + + # TODO: support legacy blend? + + prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str) + + if getattr(Globals, "log_tokenization", False): + log_tokenization_for_prompt_object(prompt, tokenizer) + + c, options = compel.build_conditioning_tensor_for_prompt_object(prompt) + + # TODO: long prompt support + #if not self.truncate_long_prompts: + # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) + + ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( + tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt), + cross_attention_control_args=options.get("cross_attention_control", None), + ) + + conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" + + # TODO: hacky but works ;D maybe rename latents somehow? + context.services.latents.set(conditioning_name, (c, ec)) + + return CompelOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + ), + ) + + +def get_max_token_count( + tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False +) -> int: + if type(prompt) is Blend: + blend: Blend = prompt + return max( + [ + get_max_token_count(tokenizer, c, truncate_if_too_long) + for c in blend.prompts + ] + ) + else: + return len( + get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long) + ) + + +def get_tokens_for_prompt_object( + tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True +) -> [str]: + if type(parsed_prompt) is Blend: + raise ValueError( + "Blend is not supported here - you need to get tokens for each of its .children" + ) + + text_fragments = [ + x.text + if type(x) is Fragment + else ( + " ".join([f.text for f in x.original]) + if type(x) is CrossAttentionControlSubstitute + else str(x) + ) + for x in parsed_prompt.children + ] + text = " ".join(text_fragments) + tokens = tokenizer.tokenize(text) + if truncate_if_too_long: + max_tokens_length = tokenizer.model_max_length - 2 # typically 75 + tokens = tokens[0:max_tokens_length] + return tokens + + +def log_tokenization_for_prompt_object( + p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None +): + display_label_prefix = display_label_prefix or "" + if type(p) is Blend: + blend: Blend = p + for i, c in enumerate(blend.prompts): + log_tokenization_for_prompt_object( + c, + tokenizer, + display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})", + ) + elif type(p) is FlattenedPrompt: + flattened_prompt: FlattenedPrompt = p + if flattened_prompt.wants_cross_attention_control: + original_fragments = [] + edited_fragments = [] + for f in flattened_prompt.children: + if type(f) is CrossAttentionControlSubstitute: + original_fragments += f.original + edited_fragments += f.edited + else: + original_fragments.append(f) + edited_fragments.append(f) + + original_text = " ".join([x.text for x in original_fragments]) + log_tokenization_for_text( + original_text, + tokenizer, + display_label=f"{display_label_prefix}(.swap originals)", + ) + edited_text = " ".join([x.text for x in edited_fragments]) + log_tokenization_for_text( + edited_text, + tokenizer, + display_label=f"{display_label_prefix}(.swap replacements)", + ) + else: + text = " ".join([x.text for x in flattened_prompt.children]) + log_tokenization_for_text( + text, tokenizer, display_label=display_label_prefix + ) + + +def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False): + """shows how the prompt is tokenized + # usually tokens have '' to indicate end-of-word, + # but for readability it has been replaced with ' ' + """ + tokens = tokenizer.tokenize(text) + tokenized = "" + discarded = "" + usedTokens = 0 + totalTokens = len(tokens) + + for i in range(0, totalTokens): + token = tokens[i].replace("", " ") + # alternate color + s = (usedTokens % 6) + 1 + if truncate_if_too_long and i >= tokenizer.model_max_length: + discarded = discarded + f"\x1b[0;3{s};40m{token}" + else: + tokenized = tokenized + f"\x1b[0;3{s};40m{token}" + usedTokens += 1 + + if usedTokens > 0: + print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):') + print(f"{tokenized}\x1b[0m") + + if discarded != "": + print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):") + print(f"{discarded}\x1b[0m") diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index dd44ee0ed9..580df3987d 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -250,8 +250,8 @@ class InpaintInvocation(ImageToImageInvocation): outputs = Inpaint(model).generate( prompt=self.prompt, - init_img=image, - init_mask=mask, + init_image=image, + mask_image=mask, step_callback=partial(self.dispatch_progress, context, source_node_id), **self.dict( exclude={"prompt", "image", "mask"} diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index c6ea8a686a..0d3ef4a8cd 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -13,13 +13,13 @@ from ...backend.model_management.model_manager import ModelManager from ...backend.util.devices import choose_torch_device, torch_dtype from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings from ...backend.image_util.seamless import configure_model_padding -from ...backend.prompting.conditioning import get_uc_and_c_and_ec from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig import numpy as np from ..services.image_storage import ImageType from .baseinvocation import BaseInvocation, InvocationContext from .image import ImageField, ImageOutput, build_image_output +from .compel import ConditioningField from ...backend.stable_diffusion import PipelineIntermediateState from diffusers.schedulers import SchedulerMixin as Scheduler import diffusers @@ -138,14 +138,14 @@ class NoiseInvocation(BaseInvocation): # Text to image class TextToLatentsInvocation(BaseInvocation): - """Generates latents from a prompt.""" + """Generates latents from conditionings.""" type: Literal["t2l"] = "t2l" # Inputs - # TODO: consider making prompt optional to enable providing prompt through a link # fmt: off - prompt: Optional[str] = Field(description="The prompt to generate an image from") + positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation") + negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") noise: Optional[LatentsField] = Field(description="The noise to use") steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) @@ -203,8 +203,10 @@ class TextToLatentsInvocation(BaseInvocation): return model - def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData: - uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model) + def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData: + c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name) + uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) + conditioning_data = ConditioningData( uc, c, @@ -231,7 +233,7 @@ class TextToLatentsInvocation(BaseInvocation): self.dispatch_progress(context, source_node_id, state) model = self.get_model(context.services.model_manager) - conditioning_data = self.get_conditioning_data(model) + conditioning_data = self.get_conditioning_data(context, model) # TODO: Verify the noise is the right size diff --git a/invokeai/app/services/default_graphs.py b/invokeai/app/services/default_graphs.py index fd0c8f5b8d..c8347c043f 100644 --- a/invokeai/app/services/default_graphs.py +++ b/invokeai/app/services/default_graphs.py @@ -1,4 +1,5 @@ from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation +from ..invocations.compel import CompelInvocation from ..invocations.params import ParamIntInvocation from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph from .item_storage import ItemStorageABC @@ -16,24 +17,32 @@ def create_text_to_image() -> LibraryGraph: nodes={ 'width': ParamIntInvocation(id='width', a=512), 'height': ParamIntInvocation(id='height', a=512), + 'seed': ParamIntInvocation(id='seed', a=-1), '3': NoiseInvocation(id='3'), - '4': TextToLatentsInvocation(id='4'), - '5': LatentsToImageInvocation(id='5') + '4': CompelInvocation(id='4'), + '5': CompelInvocation(id='5'), + '6': TextToLatentsInvocation(id='6'), + '7': LatentsToImageInvocation(id='7'), }, edges=[ Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')), Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')), - Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')), - Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')), + Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')), + Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')), + Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')), + Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')), + Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')), ] ), exposed_inputs=[ - ExposedNodeInput(node_path='4', field='prompt', alias='prompt'), + ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'), + ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'), ExposedNodeInput(node_path='width', field='a', alias='width'), - ExposedNodeInput(node_path='height', field='a', alias='height') + ExposedNodeInput(node_path='height', field='a', alias='height'), + ExposedNodeInput(node_path='seed', field='a', alias='seed'), ], exposed_outputs=[ - ExposedNodeOutput(node_path='5', field='image', alias='image') + ExposedNodeOutput(node_path='7', field='image', alias='image') ]) diff --git a/invokeai/backend/web/invoke_ai_web_server.py b/invokeai/backend/web/invoke_ai_web_server.py index 84478d5cb6..97687bd2bf 100644 --- a/invokeai/backend/web/invoke_ai_web_server.py +++ b/invokeai/backend/web/invoke_ai_web_server.py @@ -78,7 +78,6 @@ class InvokeAIWebServer: mimetypes.add_type("application/javascript", ".js") mimetypes.add_type("text/css", ".css") # Socket IO - logger = True if args.web_verbose else False engineio_logger = True if args.web_verbose else False max_http_buffer_size = 10000000 diff --git a/invokeai/frontend/web/src/common/util/parseMetadata.ts b/invokeai/frontend/web/src/common/util/parseMetadata.ts index 210c1f85ab..c27833218b 100644 --- a/invokeai/frontend/web/src/common/util/parseMetadata.ts +++ b/invokeai/frontend/web/src/common/util/parseMetadata.ts @@ -1,5 +1,5 @@ import { forEach, size } from 'lodash-es'; -import { ImageField, LatentsField } from 'services/api'; +import { ImageField, LatentsField, ConditioningField } from 'services/api'; const OBJECT_TYPESTRING = '[object Object]'; const STRING_TYPESTRING = '[object String]'; @@ -74,8 +74,38 @@ const parseLatentsField = (latentsField: unknown): LatentsField | undefined => { }; }; +const parseConditioningField = ( + conditioningField: unknown +): ConditioningField | undefined => { + // Must be an object + if (!isObject(conditioningField)) { + return; + } + + // A ConditioningField must have a `conditioning_name` + if (!('conditioning_name' in conditioningField)) { + return; + } + + // A ConditioningField's `conditioning_name` must be a string + if (typeof conditioningField.conditioning_name !== 'string') { + return; + } + + // Build a valid ConditioningField + return { + conditioning_name: conditioningField.conditioning_name, + }; +}; + type NodeMetadata = { - [key: string]: string | number | boolean | ImageField | LatentsField; + [key: string]: + | string + | number + | boolean + | ImageField + | LatentsField + | ConditioningField; }; type InvokeAIMetadata = { @@ -101,7 +131,7 @@ export const parseNodeMetadata = ( return; } - // the only valid object types are ImageField and LatentsField + // the only valid object types are ImageField, LatentsField and ConditioningField if (isObject(nodeItem)) { if ('image_name' in nodeItem || 'image_type' in nodeItem) { const imageField = parseImageField(nodeItem); @@ -118,6 +148,14 @@ export const parseNodeMetadata = ( } return; } + + if ('conditioning_name' in nodeItem) { + const conditioningField = parseConditioningField(nodeItem); + if (conditioningField) { + parsed[nodeKey] = conditioningField; + } + return; + } } // otherwise we accept any string, number or boolean diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx index 21e4b9fcfb..01d6d01b48 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx @@ -6,6 +6,7 @@ import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent'; import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; +import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent'; @@ -84,6 +85,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => { ); } + if (type === 'conditioning' && template.type === 'conditioning') { + return ( + + ); + } + if (type === 'model' && template.type === 'model') { return ( +) => { + const { nodeId, field } = props; + + return null; +}; + +export default memo(ConditioningInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 01497651e3..73bd7bb0a1 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -11,6 +11,7 @@ export const FIELD_TYPE_MAP: Record = { enum: 'enum', ImageField: 'image', LatentsField: 'latents', + ConditioningField: 'conditioning', model: 'model', array: 'array', }; @@ -63,6 +64,12 @@ export const FIELDS: Record = { title: 'Latents', description: 'Latents may be passed between nodes.', }, + conditioning: { + color: 'cyan', + colorCssVar: getColorTokenCssVariable('cyan'), + title: 'Conditioning', + description: 'Conditioning may be passed between nodes.', + }, model: { color: 'teal', colorCssVar: getColorTokenCssVariable('teal'), diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 4b5548e351..568c5fa831 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -56,6 +56,7 @@ export type FieldType = | 'enum' | 'image' | 'latents' + | 'conditioning' | 'model' | 'array'; @@ -74,6 +75,7 @@ export type InputFieldValue = | BooleanInputFieldValue | ImageInputFieldValue | LatentsInputFieldValue + | ConditioningInputFieldValue | EnumInputFieldValue | ModelInputFieldValue | ArrayInputFieldValue; @@ -91,6 +93,7 @@ export type InputFieldTemplate = | BooleanInputFieldTemplate | ImageInputFieldTemplate | LatentsInputFieldTemplate + | ConditioningInputFieldTemplate | EnumInputFieldTemplate | ModelInputFieldTemplate | ArrayInputFieldTemplate; @@ -162,6 +165,11 @@ export type LatentsInputFieldValue = FieldValueBase & { value?: undefined; }; +export type ConditioningInputFieldValue = FieldValueBase & { + type: 'conditioning'; + value?: undefined; +}; + export type ImageInputFieldValue = FieldValueBase & { type: 'image'; value?: Pick; @@ -229,6 +237,11 @@ export type LatentsInputFieldTemplate = InputFieldTemplateBase & { type: 'latents'; }; +export type ConditioningInputFieldTemplate = InputFieldTemplateBase & { + default: undefined; + type: 'conditioning'; +}; + export type EnumInputFieldTemplate = InputFieldTemplateBase & { default: string | number; type: 'enum'; diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index df895ba4af..9ce942f797 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -9,6 +9,7 @@ import { ImageInputFieldTemplate, IntegerInputFieldTemplate, LatentsInputFieldTemplate, + ConditioningInputFieldTemplate, StringInputFieldTemplate, ModelInputFieldTemplate, InputFieldTemplateBase, @@ -196,6 +197,21 @@ const buildLatentsInputFieldTemplate = ({ return template; }; +const buildConditioningInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): ConditioningInputFieldTemplate => { + const template: ConditioningInputFieldTemplate = { + ...baseField, + type: 'conditioning', + inputRequirement: 'always', + inputKind: 'connection', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildEnumInputFieldTemplate = ({ schemaObject, baseField, @@ -266,6 +282,9 @@ export const buildInputFieldTemplate = ( if (['latents'].includes(fieldType)) { return buildLatentsInputFieldTemplate({ schemaObject, baseField }); } + if (['conditioning'].includes(fieldType)) { + return buildConditioningInputFieldTemplate({ schemaObject, baseField }); + } if (['model'].includes(fieldType)) { return buildModelInputFieldTemplate({ schemaObject, baseField }); } diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index f2db2b5dc4..9221e5f7ac 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -48,6 +48,10 @@ export const buildInputFieldValue = ( fieldValue.value = undefined; } + if (template.type === 'conditioning') { + fieldValue.value = undefined; + } + if (template.type === 'model') { fieldValue.value = undefined; } diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py index c7693b59c9..82818414b2 100644 --- a/tests/nodes/test_node_graph.py +++ b/tests/nodes/test_node_graph.py @@ -463,16 +463,16 @@ def test_graph_subgraph_t2i(): n4 = ShowImageInvocation(id = "4") g.add_node(n4) - g.add_edge(create_edge("1.5","image","4","image")) + g.add_edge(create_edge("1.7","image","4","image")) # Validate dg = g.nx_graph_flat() - assert set(dg.nodes) == set(['1.width', '1.height', '1.3', '1.4', '1.5', '2', '3', '4']) + assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '1.7', '2', '3', '4']) expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges] expected_edges.extend([ ('2','1.width'), ('3','1.height'), - ('1.5','4') + ('1.7','4') ]) print(expected_edges) print(list(dg.edges))