diff --git a/.github/workflows/pyflakes.yml b/.github/workflows/pyflakes.yml deleted file mode 100644 index 4bda2dd103..0000000000 --- a/.github/workflows/pyflakes.yml +++ /dev/null @@ -1,20 +0,0 @@ -on: - pull_request: - push: - branches: - - main - - development - - 'release-candidate-*' - -jobs: - pyflakes: - name: runner / pyflakes - if: github.event.pull_request.draft == false - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - name: pyflakes - uses: reviewdog/action-pyflakes@v1 - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - reporter: github-pr-review diff --git a/.github/workflows/style-checks.yml b/.github/workflows/style-checks.yml index 08ff8ba402..121af62e1b 100644 --- a/.github/workflows/style-checks.yml +++ b/.github/workflows/style-checks.yml @@ -18,8 +18,7 @@ jobs: - name: Install dependencies with pip run: | - pip install black flake8 Flake8-pyproject isort + pip install ruff - - run: isort --check-only . - - run: black --check . - - run: flake8 + - run: ruff check --output-format=github . + - run: ruff format --check . diff --git a/installer/lib/messages.py b/installer/lib/messages.py index e4c03bbfd2..6d95eaff59 100644 --- a/installer/lib/messages.py +++ b/installer/lib/messages.py @@ -137,7 +137,7 @@ def dest_path(dest=None) -> Path: path_completer = PathCompleter( only_directories=True, expanduser=True, - get_paths=lambda: [browse_start], + get_paths=lambda: [browse_start], # noqa: B023 # get_paths=lambda: [".."].extend(list(browse_start.iterdir())) ) @@ -149,7 +149,7 @@ def dest_path(dest=None) -> Path: completer=path_completer, default=str(browse_start) + os.sep, vi_mode=True, - complete_while_typing=True + complete_while_typing=True, # Test that this is not needed on Windows # complete_style=CompleteStyle.READLINE_LIKE, ) diff --git a/invokeai/app/api/events.py b/invokeai/app/api/events.py index 40dfdb2c71..2ac07e6dfe 100644 --- a/invokeai/app/api/events.py +++ b/invokeai/app/api/events.py @@ -28,7 +28,7 @@ class FastAPIEventService(EventServiceBase): self.__queue.put(None) def dispatch(self, event_name: str, payload: Any) -> None: - self.__queue.put(dict(event_name=event_name, payload=payload)) + self.__queue.put({"event_name": event_name, "payload": payload}) async def __dispatch_from_queue(self, stop_event: threading.Event): """Get events on from the queue and dispatch them, from the correct thread""" diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index afa7d8df82..cf3d31cc38 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -55,7 +55,7 @@ async def list_models( ) -> ModelsList: """Gets a list of models""" if base_models and len(base_models) > 0: - models_raw = list() + models_raw = [] for base_model in base_models: models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)) else: diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 51aa14c75b..19bdd084e2 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -130,7 +130,7 @@ def custom_openapi() -> dict[str, Any]: # Add all outputs all_invocations = BaseInvocation.get_invocations() output_types = set() - output_type_titles = dict() + output_type_titles = {} for invoker in all_invocations: output_type = signature(invoker.invoke).return_annotation output_types.add(output_type) @@ -171,12 +171,12 @@ def custom_openapi() -> dict[str, Any]: # print(f"Config with name {name} already defined") continue - 
openapi_schema["components"]["schemas"][name] = dict( - title=name, - description="An enumeration.", - type="string", - enum=list(v.value for v in model_config_format_enum), - ) + openapi_schema["components"]["schemas"][name] = { + "title": name, + "description": "An enumeration.", + "type": "string", + "enum": [v.value for v in model_config_format_enum], + } app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/invokeai/app/invocations/__init__.py b/invokeai/app/invocations/__init__.py index 32cf73d215..718c4a7c38 100644 --- a/invokeai/app/invocations/__init__.py +++ b/invokeai/app/invocations/__init__.py @@ -25,4 +25,4 @@ spec.loader.exec_module(module) # add core nodes to __all__ python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py")) -__all__ = list(f.stem for f in python_files) # type: ignore +__all__ = [f.stem for f in python_files] # type: ignore diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index ea79e0cceb..1b3e535d34 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -236,35 +236,35 @@ def InputField( Ignored for non-collection fields. """ - json_schema_extra_: dict[str, Any] = dict( - input=input, - ui_type=ui_type, - ui_component=ui_component, - ui_hidden=ui_hidden, - ui_order=ui_order, - item_default=item_default, - ui_choice_labels=ui_choice_labels, - _field_kind="input", - ) + json_schema_extra_: dict[str, Any] = { + "input": input, + "ui_type": ui_type, + "ui_component": ui_component, + "ui_hidden": ui_hidden, + "ui_order": ui_order, + "item_default": item_default, + "ui_choice_labels": ui_choice_labels, + "_field_kind": "input", + } - field_args = dict( - default=default, - default_factory=default_factory, - title=title, - description=description, - pattern=pattern, - strict=strict, - gt=gt, - ge=ge, - lt=lt, - le=le, - multiple_of=multiple_of, - allow_inf_nan=allow_inf_nan, - max_digits=max_digits, - decimal_places=decimal_places, - min_length=min_length, - max_length=max_length, - ) + field_args = { + "default": default, + "default_factory": default_factory, + "title": title, + "description": description, + "pattern": pattern, + "strict": strict, + "gt": gt, + "ge": ge, + "lt": lt, + "le": le, + "multiple_of": multiple_of, + "allow_inf_nan": allow_inf_nan, + "max_digits": max_digits, + "decimal_places": decimal_places, + "min_length": min_length, + "max_length": max_length, + } """ Invocation definitions have their fields typed correctly for their `invoke()` functions. 
@@ -299,24 +299,24 @@ def InputField( # because we are manually making fields optional, we need to store the original required bool for reference later if default is PydanticUndefined and default_factory is PydanticUndefined: - json_schema_extra_.update(dict(orig_required=True)) + json_schema_extra_.update({"orig_required": True}) else: - json_schema_extra_.update(dict(orig_required=False)) + json_schema_extra_.update({"orig_required": False}) # make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined: default_ = None if default is PydanticUndefined else default - provided_args.update(dict(default=default_)) + provided_args.update({"default": default_}) if default is not PydanticUndefined: # before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value - json_schema_extra_.update(dict(default=default)) - json_schema_extra_.update(dict(orig_default=default)) + json_schema_extra_.update({"default": default}) + json_schema_extra_.update({"orig_default": default}) elif default is not PydanticUndefined and default_factory is PydanticUndefined: default_ = default - provided_args.update(dict(default=default_)) - json_schema_extra_.update(dict(orig_default=default_)) + provided_args.update({"default": default_}) + json_schema_extra_.update({"orig_default": default_}) elif default_factory is not PydanticUndefined: - provided_args.update(dict(default_factory=default_factory)) + provided_args.update({"default_factory": default_factory}) # TODO: cannot serialize default_factory... # json_schema_extra_.update(dict(orig_default_factory=default_factory)) @@ -383,12 +383,12 @@ def OutputField( decimal_places=decimal_places, min_length=min_length, max_length=max_length, - json_schema_extra=dict( - ui_type=ui_type, - ui_hidden=ui_hidden, - ui_order=ui_order, - _field_kind="output", - ), + json_schema_extra={ + "ui_type": ui_type, + "ui_hidden": ui_hidden, + "ui_order": ui_order, + "_field_kind": "output", + }, ) @@ -460,14 +460,14 @@ class BaseInvocationOutput(BaseModel): @classmethod def get_output_types(cls) -> Iterable[str]: - return map(lambda i: get_type(i), BaseInvocationOutput.get_outputs()) + return (get_type(i) for i in BaseInvocationOutput.get_outputs()) @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: # Because we use a pydantic Literal field with default value for the invocation type, # it will be typed as optional in the OpenAPI schema. Make it required manually. 
if "required" not in schema or not isinstance(schema["required"], list): - schema["required"] = list() + schema["required"] = [] schema["required"].extend(["type"]) model_config = ConfigDict( @@ -527,16 +527,11 @@ class BaseInvocation(ABC, BaseModel): @classmethod def get_invocations_map(cls) -> dict[str, BaseInvocation]: # Get the type strings out of the literals and into a dictionary - return dict( - map( - lambda i: (get_type(i), i), - BaseInvocation.get_invocations(), - ) - ) + return {get_type(i): i for i in BaseInvocation.get_invocations()} @classmethod def get_invocation_types(cls) -> Iterable[str]: - return map(lambda i: get_type(i), BaseInvocation.get_invocations()) + return (get_type(i) for i in BaseInvocation.get_invocations()) @classmethod def get_output_type(cls) -> BaseInvocationOutput: @@ -555,7 +550,7 @@ class BaseInvocation(ABC, BaseModel): if uiconfig and hasattr(uiconfig, "version"): schema["version"] = uiconfig.version if "required" not in schema or not isinstance(schema["required"], list): - schema["required"] = list() + schema["required"] = [] schema["required"].extend(["type", "id"]) @abstractmethod @@ -609,15 +604,15 @@ class BaseInvocation(ABC, BaseModel): id: str = Field( default_factory=uuid_string, description="The id of this instance of an invocation. Must be unique among all instances of invocations.", - json_schema_extra=dict(_field_kind="internal"), + json_schema_extra={"_field_kind": "internal"}, ) is_intermediate: bool = Field( default=False, description="Whether or not this is an intermediate invocation.", - json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"), + json_schema_extra={"ui_type": UIType.IsIntermediate, "_field_kind": "internal"}, ) use_cache: bool = Field( - default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal") + default=True, description="Whether or not to use the cache", json_schema_extra={"_field_kind": "internal"} ) UIConfig: ClassVar[Type[UIConfigBase]] @@ -651,7 +646,7 @@ class _Model(BaseModel): # Get all pydantic model attrs, methods, etc -RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model()))) +RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())} def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None: @@ -666,9 +661,7 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None field_kind = ( # _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file - field.json_schema_extra.get("_field_kind", None) - if field.json_schema_extra - else None + field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None ) # must have a field_kind @@ -729,7 +722,7 @@ def invocation( # Add OpenAPI schema extras uiconf_name = cls.__qualname__ + ".UIConfig" if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name: - cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict()) + cls.UIConfig = type(uiconf_name, (UIConfigBase,), {}) if title is not None: cls.UIConfig.title = title if tags is not None: @@ -756,7 +749,7 @@ def invocation( invocation_type_annotation = Literal[invocation_type] # type: ignore invocation_type_field = Field( - title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal") + title="type", default=invocation_type, json_schema_extra={"_field_kind": "internal"} ) docstring = cls.__doc__ @@ -802,7 +795,7 @@ def invocation_output( # Add the 
output type to the model. output_type_annotation = Literal[output_type] # type: ignore - output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal")) + output_type_field = Field(title="type", default=output_type, json_schema_extra={"_field_kind": "internal"}) docstring = cls.__doc__ cls = create_model( @@ -834,7 +827,7 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField) class WithWorkflow(BaseModel): workflow: Optional[WorkflowField] = Field( - default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal") + default=None, description=FieldDescriptions.workflow, json_schema_extra={"_field_kind": "internal"} ) @@ -852,5 +845,5 @@ MetadataFieldValidator = TypeAdapter(MetadataField) class WithMetadata(BaseModel): metadata: Optional[MetadataField] = Field( - default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal") + default=None, description=FieldDescriptions.metadata, json_schema_extra={"_field_kind": "internal"} ) diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index 0bb24ef69d..41d1ef1e4b 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -131,7 +131,7 @@ def prepare_faces_list( deduped_faces: list[FaceResultData] = [] if len(face_result_list) == 0: - return list() + return [] for candidate in face_result_list: should_add = True @@ -210,7 +210,7 @@ def generate_face_box_mask( # Check if any face is detected. if results.multi_face_landmarks: # type: ignore # this are via protobuf and not typed # Search for the face_id in the detected faces. - for face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed + for _face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed # Get the bounding box of the face mesh. x_coordinates = [landmark.x for landmark in face_landmarks.landmark] y_coordinates = [landmark.y for landmark in face_landmarks.landmark] diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index c9a0ca4423..9412aec39b 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -77,7 +77,7 @@ if choose_torch_device() == torch.device("mps"): DEFAULT_PRECISION = choose_precision(choose_torch_device()) -SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))] +SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] @invocation_output("scheduler_output") @@ -1105,7 +1105,7 @@ class BlendLatentsInvocation(BaseInvocation): latents_b = context.services.latents.get(self.latents_b.latents_name) if latents_a.shape != latents_b.shape: - raise "Latents to blend must be the same size." 
+ raise Exception("Latents to blend must be the same size.") # TODO: device = choose_torch_device() diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 585122d091..defc61275f 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -145,17 +145,17 @@ INTEGER_OPERATIONS = Literal[ ] -INTEGER_OPERATIONS_LABELS = dict( - ADD="Add A+B", - SUB="Subtract A-B", - MUL="Multiply A*B", - DIV="Divide A/B", - EXP="Exponentiate A^B", - MOD="Modulus A%B", - ABS="Absolute Value of A", - MIN="Minimum(A,B)", - MAX="Maximum(A,B)", -) +INTEGER_OPERATIONS_LABELS = { + "ADD": "Add A+B", + "SUB": "Subtract A-B", + "MUL": "Multiply A*B", + "DIV": "Divide A/B", + "EXP": "Exponentiate A^B", + "MOD": "Modulus A%B", + "ABS": "Absolute Value of A", + "MIN": "Minimum(A,B)", + "MAX": "Maximum(A,B)", +} @invocation( @@ -231,17 +231,17 @@ FLOAT_OPERATIONS = Literal[ ] -FLOAT_OPERATIONS_LABELS = dict( - ADD="Add A+B", - SUB="Subtract A-B", - MUL="Multiply A*B", - DIV="Divide A/B", - EXP="Exponentiate A^B", - ABS="Absolute Value of A", - SQRT="Square Root of A", - MIN="Minimum(A,B)", - MAX="Maximum(A,B)", -) +FLOAT_OPERATIONS_LABELS = { + "ADD": "Add A+B", + "SUB": "Subtract A-B", + "MUL": "Multiply A*B", + "DIV": "Divide A/B", + "EXP": "Exponentiate A^B", + "ABS": "Absolute Value of A", + "SQRT": "Square Root of A", + "MIN": "Minimum(A,B)", + "MAX": "Maximum(A,B)", +} @invocation( @@ -266,7 +266,7 @@ class FloatMathInvocation(BaseInvocation): raise ValueError("Cannot divide by zero") elif info.data["operation"] == "EXP" and info.data["a"] == 0 and v < 0: raise ValueError("Cannot raise zero to a negative power") - elif info.data["operation"] == "EXP" and type(info.data["a"] ** v) is complex: + elif info.data["operation"] == "EXP" and isinstance(info.data["a"] ** v, complex): raise ValueError("Root operation resulted in a complex number") return v diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 699930fc06..45b5ed61b1 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -54,7 +54,7 @@ ORT_TO_NP_TYPE = { "tensor(double)": np.float64, } -PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))] +PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())] @invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0") @@ -252,7 +252,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation): scheduler.set_timesteps(self.steps) latents = latents * np.float64(scheduler.init_noise_sigma) - extra_step_kwargs = dict() + extra_step_kwargs = {} if "eta" in set(inspect.signature(scheduler.step).parameters.keys()): extra_step_kwargs.update( eta=0.0, diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index 0e86fb978b..dccd18f754 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -100,7 +100,7 @@ EASING_FUNCTIONS_MAP = { "BounceInOut": BounceEaseInOut, } -EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))] +EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())] # actually I think for now could just use CollectionOutput (which is list[Any] @@ -161,7 +161,7 @@ class StepParamEasingInvocation(BaseInvocation): easing_class = EASING_FUNCTIONS_MAP[self.easing] if log_diagnostics: context.services.logger.debug("easing class: " + str(easing_class)) - easing_list = list() + easing_list = [] if self.mirror: # "expected" 
mirroring # if number of steps is even, squeeze duration down to (number_of_steps)/2 # and create reverse copy of list to append @@ -178,7 +178,7 @@ class StepParamEasingInvocation(BaseInvocation): end=self.end_value, duration=base_easing_duration - 1, ) - base_easing_vals = list() + base_easing_vals = [] for step_index in range(base_easing_duration): easing_val = easing_function.ease(step_index) base_easing_vals.append(easing_val) diff --git a/invokeai/app/services/board_image_records/board_image_records_sqlite.py b/invokeai/app/services/board_image_records/board_image_records_sqlite.py index 9f4e4379bc..02bafd00ec 100644 --- a/invokeai/app/services/board_image_records/board_image_records_sqlite.py +++ b/invokeai/app/services/board_image_records/board_image_records_sqlite.py @@ -139,7 +139,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): (board_id,), ) result = cast(list[sqlite3.Row], self._cursor.fetchall()) - images = list(map(lambda r: deserialize_image_record(dict(r)), result)) + images = [deserialize_image_record(dict(r)) for r in result] self._cursor.execute( """--sql @@ -167,7 +167,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): (board_id,), ) result = cast(list[sqlite3.Row], self._cursor.fetchall()) - image_names = list(map(lambda r: r[0], result)) + image_names = [r[0] for r in result] return image_names except sqlite3.Error as e: self._conn.rollback() diff --git a/invokeai/app/services/board_records/board_records_sqlite.py b/invokeai/app/services/board_records/board_records_sqlite.py index 9e3423ab19..ef507def2a 100644 --- a/invokeai/app/services/board_records/board_records_sqlite.py +++ b/invokeai/app/services/board_records/board_records_sqlite.py @@ -199,7 +199,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): ) result = cast(list[sqlite3.Row], self._cursor.fetchall()) - boards = list(map(lambda r: deserialize_board_record(dict(r)), result)) + boards = [deserialize_board_record(dict(r)) for r in result] # Get the total number of boards self._cursor.execute( @@ -236,7 +236,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): ) result = cast(list[sqlite3.Row], self._cursor.fetchall()) - boards = list(map(lambda r: deserialize_board_record(dict(r)), result)) + boards = [deserialize_board_record(dict(r)) for r in result] return boards diff --git a/invokeai/app/services/config/config_base.py b/invokeai/app/services/config/config_base.py index 9405c1dfae..66c456d311 100644 --- a/invokeai/app/services/config/config_base.py +++ b/invokeai/app/services/config/config_base.py @@ -55,7 +55,7 @@ class InvokeAISettings(BaseSettings): """ cls = self.__class__ type = get_args(get_type_hints(cls)["type"])[0] - field_dict = dict({type: dict()}) + field_dict = {type: {}} for name, field in self.model_fields.items(): if name in cls._excluded_from_yaml(): continue @@ -64,7 +64,7 @@ class InvokeAISettings(BaseSettings): ) value = getattr(self, name) if category not in field_dict[type]: - field_dict[type][category] = dict() + field_dict[type][category] = {} # keep paths as strings to make it easier to read field_dict[type][category][name] = str(value) if isinstance(value, Path) else value conf = OmegaConf.create(field_dict) @@ -89,7 +89,7 @@ class InvokeAISettings(BaseSettings): # create an upcase version of the environment in # order to achieve case-insensitive environment # variables (the way Windows does) - upcase_environ = dict() + upcase_environ = {} for key, value in os.environ.items(): upcase_environ[key.upper()] = value diff 
--git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index f0e9dbcda4..30c6694ddb 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -188,18 +188,18 @@ DEFAULT_MAX_VRAM = 0.5 class Categories(object): - WebServer = dict(category="Web Server") - Features = dict(category="Features") - Paths = dict(category="Paths") - Logging = dict(category="Logging") - Development = dict(category="Development") - Other = dict(category="Other") - ModelCache = dict(category="Model Cache") - Device = dict(category="Device") - Generation = dict(category="Generation") - Queue = dict(category="Queue") - Nodes = dict(category="Nodes") - MemoryPerformance = dict(category="Memory/Performance") + WebServer = {"category": "Web Server"} + Features = {"category": "Features"} + Paths = {"category": "Paths"} + Logging = {"category": "Logging"} + Development = {"category": "Development"} + Other = {"category": "Other"} + ModelCache = {"category": "Model Cache"} + Device = {"category": "Device"} + Generation = {"category": "Generation"} + Queue = {"category": "Queue"} + Nodes = {"category": "Nodes"} + MemoryPerformance = {"category": "Memory/Performance"} class InvokeAIAppConfig(InvokeAISettings): @@ -482,7 +482,7 @@ def _find_root() -> Path: venv = Path(os.environ.get("VIRTUAL_ENV") or ".") if os.environ.get("INVOKEAI_ROOT"): root = Path(os.environ["INVOKEAI_ROOT"]) - elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]): + elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]): root = (venv.parent).resolve() else: root = Path("~/invokeai").expanduser().resolve() diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index ad00815151..dd4152e609 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -27,7 +27,7 @@ class EventServiceBase: payload["timestamp"] = get_timestamp() self.dispatch( event_name=EventServiceBase.queue_event, - payload=dict(event=event_name, data=payload), + payload={"event": event_name, "data": payload}, ) # Define events here for every event in the system. 
@@ -48,18 +48,18 @@ class EventServiceBase: """Emitted when there is generation progress""" self.__emit_queue_event( event_name="generator_progress", - payload=dict( - queue_id=queue_id, - queue_item_id=queue_item_id, - queue_batch_id=queue_batch_id, - graph_execution_state_id=graph_execution_state_id, - node_id=node.get("id"), - source_node_id=source_node_id, - progress_image=progress_image.model_dump() if progress_image is not None else None, - step=step, - order=order, - total_steps=total_steps, - ), + payload={ + "queue_id": queue_id, + "queue_item_id": queue_item_id, + "queue_batch_id": queue_batch_id, + "graph_execution_state_id": graph_execution_state_id, + "node_id": node.get("id"), + "source_node_id": source_node_id, + "progress_image": progress_image.model_dump() if progress_image is not None else None, + "step": step, + "order": order, + "total_steps": total_steps, + }, ) def emit_invocation_complete( @@ -75,15 +75,15 @@ class EventServiceBase: """Emitted when an invocation has completed""" self.__emit_queue_event( event_name="invocation_complete", - payload=dict( - queue_id=queue_id, - queue_item_id=queue_item_id, - queue_batch_id=queue_batch_id, - graph_execution_state_id=graph_execution_state_id, - node=node, - source_node_id=source_node_id, - result=result, - ), + payload={ + "queue_id": queue_id, + "queue_item_id": queue_item_id, + "queue_batch_id": queue_batch_id, + "graph_execution_state_id": graph_execution_state_id, + "node": node, + "source_node_id": source_node_id, + "result": result, + }, ) def emit_invocation_error( @@ -100,16 +100,16 @@ class EventServiceBase: """Emitted when an invocation has completed""" self.__emit_queue_event( event_name="invocation_error", - payload=dict( - queue_id=queue_id, - queue_item_id=queue_item_id, - queue_batch_id=queue_batch_id, - graph_execution_state_id=graph_execution_state_id, - node=node, - source_node_id=source_node_id, - error_type=error_type, - error=error, - ), + payload={ + "queue_id": queue_id, + "queue_item_id": queue_item_id, + "queue_batch_id": queue_batch_id, + "graph_execution_state_id": graph_execution_state_id, + "node": node, + "source_node_id": source_node_id, + "error_type": error_type, + "error": error, + }, ) def emit_invocation_started( @@ -124,14 +124,14 @@ class EventServiceBase: """Emitted when an invocation has started""" self.__emit_queue_event( event_name="invocation_started", - payload=dict( - queue_id=queue_id, - queue_item_id=queue_item_id, - queue_batch_id=queue_batch_id, - graph_execution_state_id=graph_execution_state_id, - node=node, - source_node_id=source_node_id, - ), + payload={ + "queue_id": queue_id, + "queue_item_id": queue_item_id, + "queue_batch_id": queue_batch_id, + "graph_execution_state_id": graph_execution_state_id, + "node": node, + "source_node_id": source_node_id, + }, ) def emit_graph_execution_complete( @@ -140,12 +140,12 @@ class EventServiceBase: """Emitted when a session has completed all invocations""" self.__emit_queue_event( event_name="graph_execution_state_complete", - payload=dict( - queue_id=queue_id, - queue_item_id=queue_item_id, - queue_batch_id=queue_batch_id, - graph_execution_state_id=graph_execution_state_id, - ), + payload={ + "queue_id": queue_id, + "queue_item_id": queue_item_id, + "queue_batch_id": queue_batch_id, + "graph_execution_state_id": graph_execution_state_id, + }, ) def emit_model_load_started( @@ -162,16 +162,16 @@ class EventServiceBase: """Emitted when a model is requested""" self.__emit_queue_event( event_name="model_load_started", - 
payload=dict( - queue_id=queue_id, - queue_item_id=queue_item_id, - queue_batch_id=queue_batch_id, - graph_execution_state_id=graph_execution_state_id, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - ), + payload={ + "queue_id": queue_id, + "queue_item_id": queue_item_id, + "queue_batch_id": queue_batch_id, + "graph_execution_state_id": graph_execution_state_id, + "model_name": model_name, + "base_model": base_model, + "model_type": model_type, + "submodel": submodel, + }, ) def emit_model_load_completed( @@ -189,19 +189,19 @@ class EventServiceBase: """Emitted when a model is correctly loaded (returns model info)""" self.__emit_queue_event( event_name="model_load_completed", - payload=dict( - queue_id=queue_id, - queue_item_id=queue_item_id, - queue_batch_id=queue_batch_id, - graph_execution_state_id=graph_execution_state_id, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - hash=model_info.hash, - location=str(model_info.location), - precision=str(model_info.precision), - ), + payload={ + "queue_id": queue_id, + "queue_item_id": queue_item_id, + "queue_batch_id": queue_batch_id, + "graph_execution_state_id": graph_execution_state_id, + "model_name": model_name, + "base_model": base_model, + "model_type": model_type, + "submodel": submodel, + "hash": model_info.hash, + "location": str(model_info.location), + "precision": str(model_info.precision), + }, ) def emit_session_retrieval_error( @@ -216,14 +216,14 @@ class EventServiceBase: """Emitted when session retrieval fails""" self.__emit_queue_event( event_name="session_retrieval_error", - payload=dict( - queue_id=queue_id, - queue_item_id=queue_item_id, - queue_batch_id=queue_batch_id, - graph_execution_state_id=graph_execution_state_id, - error_type=error_type, - error=error, - ), + payload={ + "queue_id": queue_id, + "queue_item_id": queue_item_id, + "queue_batch_id": queue_batch_id, + "graph_execution_state_id": graph_execution_state_id, + "error_type": error_type, + "error": error, + }, ) def emit_invocation_retrieval_error( @@ -239,15 +239,15 @@ class EventServiceBase: """Emitted when invocation retrieval fails""" self.__emit_queue_event( event_name="invocation_retrieval_error", - payload=dict( - queue_id=queue_id, - queue_item_id=queue_item_id, - queue_batch_id=queue_batch_id, - graph_execution_state_id=graph_execution_state_id, - node_id=node_id, - error_type=error_type, - error=error, - ), + payload={ + "queue_id": queue_id, + "queue_item_id": queue_item_id, + "queue_batch_id": queue_batch_id, + "graph_execution_state_id": graph_execution_state_id, + "node_id": node_id, + "error_type": error_type, + "error": error, + }, ) def emit_session_canceled( @@ -260,12 +260,12 @@ class EventServiceBase: """Emitted when a session is canceled""" self.__emit_queue_event( event_name="session_canceled", - payload=dict( - queue_id=queue_id, - queue_item_id=queue_item_id, - queue_batch_id=queue_batch_id, - graph_execution_state_id=graph_execution_state_id, - ), + payload={ + "queue_id": queue_id, + "queue_item_id": queue_item_id, + "queue_batch_id": queue_batch_id, + "graph_execution_state_id": graph_execution_state_id, + }, ) def emit_queue_item_status_changed( @@ -277,39 +277,39 @@ class EventServiceBase: """Emitted when a queue item's status changes""" self.__emit_queue_event( event_name="queue_item_status_changed", - payload=dict( - queue_id=queue_status.queue_id, - queue_item=dict( - queue_id=session_queue_item.queue_id, - 
item_id=session_queue_item.item_id, - status=session_queue_item.status, - batch_id=session_queue_item.batch_id, - session_id=session_queue_item.session_id, - error=session_queue_item.error, - created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None, - updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None, - started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None, - completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None, - ), - batch_status=batch_status.model_dump(), - queue_status=queue_status.model_dump(), - ), + payload={ + "queue_id": queue_status.queue_id, + "queue_item": { + "queue_id": session_queue_item.queue_id, + "item_id": session_queue_item.item_id, + "status": session_queue_item.status, + "batch_id": session_queue_item.batch_id, + "session_id": session_queue_item.session_id, + "error": session_queue_item.error, + "created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None, + "updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None, + "started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None, + "completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None, + }, + "batch_status": batch_status.model_dump(), + "queue_status": queue_status.model_dump(), + }, ) def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None: """Emitted when a batch is enqueued""" self.__emit_queue_event( event_name="batch_enqueued", - payload=dict( - queue_id=enqueue_result.queue_id, - batch_id=enqueue_result.batch.batch_id, - enqueued=enqueue_result.enqueued, - ), + payload={ + "queue_id": enqueue_result.queue_id, + "batch_id": enqueue_result.batch.batch_id, + "enqueued": enqueue_result.enqueued, + }, ) def emit_queue_cleared(self, queue_id: str) -> None: """Emitted when the queue is cleared""" self.__emit_queue_event( event_name="queue_cleared", - payload=dict(queue_id=queue_id), + payload={"queue_id": queue_id}, ) diff --git a/invokeai/app/services/image_files/image_files_disk.py b/invokeai/app/services/image_files/image_files_disk.py index 91c1e14789..cffcb702c9 100644 --- a/invokeai/app/services/image_files/image_files_disk.py +++ b/invokeai/app/services/image_files/image_files_disk.py @@ -25,7 +25,7 @@ class DiskImageFileStorage(ImageFileStorageBase): __invoker: Invoker def __init__(self, output_folder: Union[str, Path]): - self.__cache = dict() + self.__cache = {} self.__cache_ids = Queue() self.__max_cache_size = 10 # TODO: get this from config diff --git a/invokeai/app/services/image_records/image_records_common.py b/invokeai/app/services/image_records/image_records_common.py index 5a6e5652c9..61b97c6032 100644 --- a/invokeai/app/services/image_records/image_records_common.py +++ b/invokeai/app/services/image_records/image_records_common.py @@ -90,25 +90,23 @@ class ImageRecordDeleteException(Exception): IMAGE_DTO_COLS = ", ".join( - list( - map( - lambda c: "images." + c, - [ - "image_name", - "image_origin", - "image_category", - "width", - "height", - "session_id", - "node_id", - "is_intermediate", - "created_at", - "updated_at", - "deleted_at", - "starred", - ], - ) - ) + [ + "images." 
+ c + for c in [ + "image_name", + "image_origin", + "image_category", + "width", + "height", + "session_id", + "node_id", + "is_intermediate", + "created_at", + "updated_at", + "deleted_at", + "starred", + ] + ] ) diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index 239917b728..e0dabf1657 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -263,7 +263,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): if categories is not None: # Convert the enum values to unique list of strings - category_strings = list(map(lambda c: c.value, set(categories))) + category_strings = [c.value for c in set(categories)] # Create the correct length of placeholders placeholders = ",".join("?" * len(category_strings)) @@ -307,7 +307,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): # Build the list of images, deserializing each row self._cursor.execute(images_query, images_params) result = cast(list[sqlite3.Row], self._cursor.fetchall()) - images = list(map(lambda r: deserialize_image_record(dict(r)), result)) + images = [deserialize_image_record(dict(r)) for r in result] # Set up and execute the count query, without pagination count_query += query_conditions + ";" @@ -386,7 +386,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """ ) result = cast(list[sqlite3.Row], self._cursor.fetchall()) - image_names = list(map(lambda r: r[0], result)) + image_names = [r[0] for r in result] self._cursor.execute( """--sql DELETE FROM images diff --git a/invokeai/app/services/images/images_base.py b/invokeai/app/services/images/images_base.py index 50a3a5fb82..b3990d08f5 100644 --- a/invokeai/app/services/images/images_base.py +++ b/invokeai/app/services/images/images_base.py @@ -21,8 +21,8 @@ class ImageServiceABC(ABC): _on_deleted_callbacks: list[Callable[[str], None]] def __init__(self) -> None: - self._on_changed_callbacks = list() - self._on_deleted_callbacks = list() + self._on_changed_callbacks = [] + self._on_deleted_callbacks = [] def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None: """Register a callback for when an image is changed""" diff --git a/invokeai/app/services/images/images_default.py b/invokeai/app/services/images/images_default.py index 8eb768a1b9..63fa78d6c8 100644 --- a/invokeai/app/services/images/images_default.py +++ b/invokeai/app/services/images/images_default.py @@ -217,18 +217,16 @@ class ImageService(ImageServiceABC): board_id, ) - image_dtos = list( - map( - lambda r: image_record_to_dto( - image_record=r, - image_url=self.__invoker.services.urls.get_image_url(r.image_name), - thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True), - board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name), - workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name), - ), - results.items, + image_dtos = [ + image_record_to_dto( + image_record=r, + image_url=self.__invoker.services.urls.get_image_url(r.image_name), + thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True), + board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name), + workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name), ) - ) + for r in results.items + ] return OffsetPaginatedResults[ImageDTO]( items=image_dtos, diff --git 
a/invokeai/app/services/invocation_processor/invocation_processor_base.py b/invokeai/app/services/invocation_processor/invocation_processor_base.py index 04774accc2..7947a201dd 100644 --- a/invokeai/app/services/invocation_processor/invocation_processor_base.py +++ b/invokeai/app/services/invocation_processor/invocation_processor_base.py @@ -1,5 +1,5 @@ from abc import ABC -class InvocationProcessorABC(ABC): +class InvocationProcessorABC(ABC): # noqa: B024 pass diff --git a/invokeai/app/services/invocation_processor/invocation_processor_default.py b/invokeai/app/services/invocation_processor/invocation_processor_default.py index c59fb678ef..6e0d3075ea 100644 --- a/invokeai/app/services/invocation_processor/invocation_processor_default.py +++ b/invokeai/app/services/invocation_processor/invocation_processor_default.py @@ -26,7 +26,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): self.__invoker_thread = Thread( name="invoker_processor", target=self.__process, - kwargs=dict(stop_event=self.__stop_event), + kwargs={"stop_event": self.__stop_event}, ) self.__invoker_thread.daemon = True # TODO: make async and do not use threads self.__invoker_thread.start() diff --git a/invokeai/app/services/invocation_queue/invocation_queue_memory.py b/invokeai/app/services/invocation_queue/invocation_queue_memory.py index 33e82fae18..8d6fff7052 100644 --- a/invokeai/app/services/invocation_queue/invocation_queue_memory.py +++ b/invokeai/app/services/invocation_queue/invocation_queue_memory.py @@ -14,7 +14,7 @@ class MemoryInvocationQueue(InvocationQueueABC): def __init__(self): self.__queue = Queue() - self.__cancellations = dict() + self.__cancellations = {} def get(self) -> InvocationQueueItem: item = self.__queue.get() diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index be019b6820..34d2cd8354 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -122,7 +122,7 @@ class InvocationStatsService(InvocationStatsServiceBase): def log_stats(self): completed = set() errored = set() - for graph_id, node_log in self._stats.items(): + for graph_id, _node_log in self._stats.items(): try: current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id) except Exception: @@ -142,7 +142,7 @@ class InvocationStatsService(InvocationStatsServiceBase): cache_stats = self._cache_stats[graph_id] hwm = cache_stats.high_watermark / GIG tot = cache_stats.cache_size / GIG - loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG + loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GIG logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s") logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)") diff --git a/invokeai/app/services/item_storage/item_storage_base.py b/invokeai/app/services/item_storage/item_storage_base.py index 1446c0cd08..e94c049ee4 100644 --- a/invokeai/app/services/item_storage/item_storage_base.py +++ b/invokeai/app/services/item_storage/item_storage_base.py @@ -15,8 +15,8 @@ class ItemStorageABC(ABC, Generic[T]): _on_deleted_callbacks: list[Callable[[str], None]] def __init__(self) -> None: - self._on_changed_callbacks = list() - self._on_deleted_callbacks = list() + self._on_changed_callbacks = [] + self._on_deleted_callbacks = [] """Base item storage class""" diff --git 
a/invokeai/app/services/item_storage/item_storage_sqlite.py b/invokeai/app/services/item_storage/item_storage_sqlite.py index d0249ebfa6..d5a1b7f730 100644 --- a/invokeai/app/services/item_storage/item_storage_sqlite.py +++ b/invokeai/app/services/item_storage/item_storage_sqlite.py @@ -112,7 +112,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): ) result = self._cursor.fetchall() - items = list(map(lambda r: self._parse_item(r[0]), result)) + items = [self._parse_item(r[0]) for r in result] self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""") count = self._cursor.fetchone()[0] @@ -132,7 +132,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): ) result = self._cursor.fetchall() - items = list(map(lambda r: self._parse_item(r[0]), result)) + items = [self._parse_item(r[0]) for r in result] self._cursor.execute( f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""", diff --git a/invokeai/app/services/latents_storage/latents_storage_base.py b/invokeai/app/services/latents_storage/latents_storage_base.py index 4850a477d3..9fa42b0ae6 100644 --- a/invokeai/app/services/latents_storage/latents_storage_base.py +++ b/invokeai/app/services/latents_storage/latents_storage_base.py @@ -13,8 +13,8 @@ class LatentsStorageBase(ABC): _on_deleted_callbacks: list[Callable[[str], None]] def __init__(self) -> None: - self._on_changed_callbacks = list() - self._on_deleted_callbacks = list() + self._on_changed_callbacks = [] + self._on_deleted_callbacks = [] @abstractmethod def get(self, name: str) -> torch.Tensor: diff --git a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py b/invokeai/app/services/latents_storage/latents_storage_forward_cache.py index 5248362ff5..da82b5904d 100644 --- a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py +++ b/invokeai/app/services/latents_storage/latents_storage_forward_cache.py @@ -19,7 +19,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase): def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20): super().__init__() self.__underlying_storage = underlying_storage - self.__cache = dict() + self.__cache = {} self.__cache_ids = Queue() self.__max_cache_size = max_cache_size diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 5b59dee254..28591fd7df 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -33,9 +33,11 @@ class DefaultSessionProcessor(SessionProcessorBase): self.__thread = Thread( name="session_processor", target=self.__process, - kwargs=dict( - stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event - ), + kwargs={ + "stop_event": self.__stop_event, + "poll_now_event": self.__poll_now_event, + "resume_event": self.__resume_event, + }, ) self.__thread.start() diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 69e6a3ab87..e7d7cdda46 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -129,12 +129,12 @@ class Batch(BaseModel): return v model_config = ConfigDict( - json_schema_extra=dict( - required=[ + json_schema_extra={ + "required": [ "graph", "runs", ] - ) + } ) @@ -191,8 +191,8 @@ class 
SessionQueueItemWithoutGraph(BaseModel): return SessionQueueItemDTO(**queue_item_dict) model_config = ConfigDict( - json_schema_extra=dict( - required=[ + json_schema_extra={ + "required": [ "item_id", "status", "batch_id", @@ -203,7 +203,7 @@ class SessionQueueItemWithoutGraph(BaseModel): "created_at", "updated_at", ] - ) + } ) @@ -222,8 +222,8 @@ class SessionQueueItem(SessionQueueItemWithoutGraph): return SessionQueueItem(**queue_item_dict) model_config = ConfigDict( - json_schema_extra=dict( - required=[ + json_schema_extra={ + "required": [ "item_id", "status", "batch_id", @@ -235,7 +235,7 @@ class SessionQueueItem(SessionQueueItemWithoutGraph): "created_at", "updated_at", ] - ) + } ) @@ -355,7 +355,7 @@ def create_session_nfv_tuples( for item in batch_datum.items ] node_field_values_to_zip.append(node_field_values) - data.append(list(zip(*node_field_values_to_zip))) # type: ignore [arg-type] + data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type] # create generator to yield session,nfv tuples count = 0 @@ -383,7 +383,7 @@ def calc_session_count(batch: Batch) -> int: for batch_datum in batch_datum_list: batch_data_items = range(len(batch_datum.items)) to_zip.append(batch_data_items) - data.append(list(zip(*to_zip))) + data.append(list(zip(*to_zip, strict=True))) data_product = list(product(*data)) return len(data_product) * batch.runs diff --git a/invokeai/app/services/shared/default_graphs.py b/invokeai/app/services/shared/default_graphs.py index 9a6e2456cb..7e62c6d0a1 100644 --- a/invokeai/app/services/shared/default_graphs.py +++ b/invokeai/app/services/shared/default_graphs.py @@ -78,7 +78,7 @@ def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[Li """Creates the default system graphs, or adds new versions if the old ones don't match""" # TODO: Uncomment this when we are ready to fix this up to prevent breaking changes - graphs: list[LibraryGraph] = list() + graphs: list[LibraryGraph] = [] text_to_image = graph_library.get(default_text_to_image_graph_id) diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index b84d456071..29af1e2333 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -352,7 +352,7 @@ class Graph(BaseModel): # Validate that all node ids are unique node_ids = [n.id for n in self.nodes.values()] - duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2]) + duplicate_node_ids = {node_id for node_id in node_ids if node_ids.count(node_id) >= 2} if duplicate_node_ids: raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}") @@ -616,7 +616,7 @@ class Graph(BaseModel): self, node_path: str, prefix: Optional[str] = None ) -> list[tuple["Graph", Union[str, None], Edge]]: """Gets all input edges for a node along with the graph they are in and the graph's path""" - edges = list() + edges = [] # Return any input edges that appear in this graph edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]) @@ -658,7 +658,7 @@ class Graph(BaseModel): self, node_path: str, prefix: Optional[str] = None ) -> list[tuple["Graph", Union[str, None], Edge]]: """Gets all output edges for a node along with the graph they are in and the graph's path""" - edges = list() + edges = [] # Return any input edges that appear in this graph edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path]) @@ -680,8 +680,8 @@ class 
Graph(BaseModel): new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None, ) -> bool: - inputs = list([e.source for e in self._get_input_edges(node_path, "collection")]) - outputs = list([e.destination for e in self._get_output_edges(node_path, "item")]) + inputs = [e.source for e in self._get_input_edges(node_path, "collection")] + outputs = [e.destination for e in self._get_output_edges(node_path, "item")] if new_input is not None: inputs.append(new_input) @@ -694,7 +694,7 @@ class Graph(BaseModel): # Get input and output fields (the fields linked to the iterator's input/output) input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field) - output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs]) + output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs] # Input type must be a list if get_origin(input_field) != list: @@ -713,8 +713,8 @@ class Graph(BaseModel): new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None, ) -> bool: - inputs = list([e.source for e in self._get_input_edges(node_path, "item")]) - outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")]) + inputs = [e.source for e in self._get_input_edges(node_path, "item")] + outputs = [e.destination for e in self._get_output_edges(node_path, "collection")] if new_input is not None: inputs.append(new_input) @@ -722,18 +722,16 @@ class Graph(BaseModel): outputs.append(new_output) # Get input and output fields (the fields linked to the iterator's input/output) - input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs]) - output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs]) + input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs] + output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs] # Validate that all inputs are derived from or match a single type - input_field_types = set( - [ - t - for input_field in input_fields - for t in ([input_field] if get_origin(input_field) is None else get_args(input_field)) - if t != NoneType - ] - ) # Get unique types + input_field_types = { + t + for input_field in input_fields + for t in ([input_field] if get_origin(input_field) is None else get_args(input_field)) + if t != NoneType + } # Get unique types type_tree = nx.DiGraph() type_tree.add_nodes_from(input_field_types) type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])]) @@ -761,15 +759,15 @@ class Graph(BaseModel): """Returns a NetworkX DiGraph representing the layout of this graph""" # TODO: Cache this? 
g = nx.DiGraph() - g.add_nodes_from([n for n in self.nodes.keys()]) - g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges])) + g.add_nodes_from(list(self.nodes.keys())) + g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges}) return g def nx_graph_with_data(self) -> nx.DiGraph: """Returns a NetworkX DiGraph representing the data and layout of this graph""" g = nx.DiGraph() - g.add_nodes_from([n for n in self.nodes.items()]) - g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges])) + g.add_nodes_from(list(self.nodes.items())) + g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges}) return g def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph: @@ -791,7 +789,7 @@ class Graph(BaseModel): # TODO: figure out if iteration nodes need to be expanded - unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges]) + unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges} g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges]) return g @@ -843,8 +841,8 @@ class GraphExecutionState(BaseModel): return v model_config = ConfigDict( - json_schema_extra=dict( - required=[ + json_schema_extra={ + "required": [ "id", "graph", "execution_graph", @@ -855,7 +853,7 @@ class GraphExecutionState(BaseModel): "prepared_source_mapping", "source_prepared_mapping", ] - ) + } ) def next(self) -> Optional[BaseInvocation]: @@ -895,7 +893,7 @@ class GraphExecutionState(BaseModel): source_node = self.prepared_source_mapping[node_id] prepared_nodes = self.source_prepared_mapping[source_node] - if all([n in self.executed for n in prepared_nodes]): + if all(n in self.executed for n in prepared_nodes): self.executed.add(source_node) self.executed_history.append(source_node) @@ -930,7 +928,7 @@ class GraphExecutionState(BaseModel): input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field) self_iteration_count = len(input_collection) - new_nodes: list[str] = list() + new_nodes: list[str] = [] if self_iteration_count == 0: # TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid. 
return new_nodes @@ -940,7 +938,7 @@ class GraphExecutionState(BaseModel): # Create new edges for this iteration # For collect nodes, this may contain multiple inputs to the same field - new_edges: list[Edge] = list() + new_edges: list[Edge] = [] for edge in input_edges: for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id): new_edge = Edge( @@ -1034,7 +1032,7 @@ class GraphExecutionState(BaseModel): # Create execution nodes next_node = self.graph.get_node(next_node_id) - new_node_ids = list() + new_node_ids = [] if isinstance(next_node, CollectInvocation): # Collapse all iterator input mappings and create a single execution node for the collect invocation all_iteration_mappings = list( @@ -1055,7 +1053,10 @@ class GraphExecutionState(BaseModel): # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator # TODO: Handle a node mapping to none eg = self.execution_graph.nx_graph_flat() - prepared_parent_mappings = [[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] # type: ignore + prepared_parent_mappings = [ + [(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] + for it in iterator_node_prepared_combinations + ] # type: ignore # Create execution node for each iteration for iteration_mappings in prepared_parent_mappings: @@ -1121,7 +1122,7 @@ class GraphExecutionState(BaseModel): for edge in input_edges if edge.destination.field == "item" ] - setattr(node, "collection", output_collection) + node.collection = output_collection else: for edge in input_edges: output_value = getattr(self.results[edge.source.node_id], edge.source.field) @@ -1201,7 +1202,7 @@ class LibraryGraph(BaseModel): @field_validator("exposed_inputs", "exposed_outputs") def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]): - if len(v) != len(set(i.alias for i in v)): + if len(v) != len({i.alias for i in v}): raise ValueError("Duplicate exposed alias") return v diff --git a/invokeai/app/util/controlnet_utils.py b/invokeai/app/util/controlnet_utils.py index 51ceec2edd..b3e2560211 100644 --- a/invokeai/app/util/controlnet_utils.py +++ b/invokeai/app/util/controlnet_utils.py @@ -59,7 +59,7 @@ def thin_one_time(x, kernels): def lvmin_thin(x, prunings=True): y = x - for i in range(32): + for _i in range(32): y, is_done = thin_one_time(y, lvmin_kernels) if is_done: break diff --git a/invokeai/app/util/metadata.py b/invokeai/app/util/metadata.py index 15951cb009..52f9750e4f 100644 --- a/invokeai/app/util/metadata.py +++ b/invokeai/app/util/metadata.py @@ -21,11 +21,11 @@ def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]: # sanity check make sure the graph is at least reasonably shaped if ( - type(graph) is not dict + not isinstance(graph, dict) or "nodes" not in graph - or type(graph["nodes"]) is not dict + or not isinstance(graph["nodes"], dict) or "edges" not in graph - or type(graph["edges"]) is not list + or not isinstance(graph["edges"], list) ): # something has gone terribly awry, return an empty dict return None diff --git a/invokeai/backend/image_util/pngwriter.py b/invokeai/backend/image_util/pngwriter.py index 47f6a44c28..c9c58264c2 100644 --- a/invokeai/backend/image_util/pngwriter.py +++ b/invokeai/backend/image_util/pngwriter.py @@ -88,7 +88,7 @@ class PromptFormatter: t2i = self.t2i opt = self.opt - switches = list() + switches = [] switches.append(f'"{opt.prompt}"') 
switches.append(f"-s{opt.steps or t2i.steps}") switches.append(f"-W{opt.width or t2i.width}") diff --git a/invokeai/backend/image_util/txt2mask.py b/invokeai/backend/image_util/txt2mask.py index de0c6a1652..5cbc7c1e38 100644 --- a/invokeai/backend/image_util/txt2mask.py +++ b/invokeai/backend/image_util/txt2mask.py @@ -88,7 +88,7 @@ class Txt2Mask(object): provided image and returns a SegmentedGrayscale object in which the brighter pixels indicate where the object is inferred to be. """ - if type(image) is str: + if isinstance(image, str): image = Image.open(image).convert("RGB") image = ImageOps.exif_transpose(image) diff --git a/invokeai/backend/image_util/util.py b/invokeai/backend/image_util/util.py index 7eceb9be82..5b8be7f118 100644 --- a/invokeai/backend/image_util/util.py +++ b/invokeai/backend/image_util/util.py @@ -40,7 +40,7 @@ class InitImageResizer: (rw, rh) = (int(scale * im.width), int(scale * im.height)) # round everything to multiples of 64 - width, height, rw, rh = map(lambda x: x - x % 64, (width, height, rw, rh)) + width, height, rw, rh = (x - x % 64 for x in (width, height, rw, rh)) # no resize necessary, but return a copy if im.width == width and im.height == height: diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index 28910e8942..2621b811ac 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -197,7 +197,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th def download_conversion_models(): target_dir = config.models_path / "core/convert" - kwargs = dict() # for future use + kwargs = {} # for future use try: logger.info("Downloading core tokenizers and text encoders") @@ -252,26 +252,26 @@ def download_conversion_models(): def download_realesrgan(): logger.info("Installing ESRGAN Upscaling models...") URLs = [ - dict( - url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", - dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth", - description="RealESRGAN_x4plus.pth", - ), - dict( - url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", - dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", - description="RealESRGAN_x4plus_anime_6B.pth", - ), - dict( - url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", - dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", - description="ESRGAN_SRx4_DF2KOST_official.pth", - ), - dict( - url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", - dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth", - description="RealESRGAN_x2plus.pth", - ), + { + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + "dest": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth", + "description": "RealESRGAN_x4plus.pth", + }, + { + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", + "dest": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", + "description": "RealESRGAN_x4plus_anime_6B.pth", + }, + { + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", + "dest": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", + "description": "ESRGAN_SRx4_DF2KOST_official.pth", + }, + { + "url": 
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", + "dest": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth", + "description": "RealESRGAN_x2plus.pth", + }, ] for model in URLs: download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"]) @@ -680,7 +680,7 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections: if program_opts.default_only else [models[x].path or models[x].repo_id for x in installer.recommended_models()] if program_opts.yes_to_all - else list(), + else [], ) diff --git a/invokeai/backend/install/migrate_to_3.py b/invokeai/backend/install/migrate_to_3.py index ea5bee8058..e15eb23f5b 100644 --- a/invokeai/backend/install/migrate_to_3.py +++ b/invokeai/backend/install/migrate_to_3.py @@ -123,8 +123,6 @@ class MigrateTo3(object): logger.error(str(e)) except KeyboardInterrupt: raise - except Exception as e: - logger.error(str(e)) for f in files: # don't copy raw learned_embeds.bin or pytorch_lora_weights.bin # let them be copied as part of a tree copy operation @@ -143,8 +141,6 @@ class MigrateTo3(object): logger.error(str(e)) except KeyboardInterrupt: raise - except Exception as e: - logger.error(str(e)) def migrate_support_models(self): """ @@ -182,10 +178,10 @@ class MigrateTo3(object): """ dest_directory = self.dest_models - kwargs = dict( - cache_dir=self.root_directory / "models/hub", + kwargs = { + "cache_dir": self.root_directory / "models/hub", # local_files_only = True - ) + } try: logger.info("Migrating core tokenizers and text encoders") target_dir = dest_directory / "core" / "convert" @@ -316,11 +312,11 @@ class MigrateTo3(object): dest_dir = self.dest_models cache = self.root_directory / "models/hub" - kwargs = dict( - cache_dir=cache, - safety_checker=None, + kwargs = { + "cache_dir": cache, + "safety_checker": None, # local_files_only = True, - ) + } owner, repo_name = repo_id.split("/") model_name = model_name or repo_name diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index 34526cfaf3..afbcc848d8 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -120,7 +120,7 @@ class ModelInstall(object): be treated uniformly. It also sorts the models alphabetically by their name, to improve the display somewhat. 
""" - model_dict = dict() + model_dict = {} # first populate with the entries in INITIAL_MODELS.yaml for key, value in self.datasets.items(): @@ -134,7 +134,7 @@ class ModelInstall(object): model_dict[key] = model_info # supplement with entries in models.yaml - installed_models = [x for x in self.mgr.list_models()] + installed_models = list(self.mgr.list_models()) for md in installed_models: base = md["base_model"] @@ -176,7 +176,7 @@ class ModelInstall(object): # logic here a little reversed to maintain backward compatibility def starter_models(self, all_models: bool = False) -> Set[str]: models = set() - for key, value in self.datasets.items(): + for key, _value in self.datasets.items(): name, base, model_type = ModelManager.parse_key(key) if all_models or model_type in [ModelType.Main, ModelType.Vae]: models.add(key) @@ -184,7 +184,7 @@ class ModelInstall(object): def recommended_models(self) -> Set[str]: starters = self.starter_models(all_models=True) - return set([x for x in starters if self.datasets[x].get("recommended", False)]) + return {x for x in starters if self.datasets[x].get("recommended", False)} def default_model(self) -> str: starters = self.starter_models() @@ -234,7 +234,7 @@ class ModelInstall(object): """ if not models_installed: - models_installed = dict() + models_installed = {} model_path_id_or_url = str(model_path_id_or_url).strip("\"' ") @@ -252,16 +252,14 @@ class ModelInstall(object): # folders style or similar elif path.is_dir() and any( - [ - (path / x).exists() - for x in { - "config.json", - "model_index.json", - "learned_embeds.bin", - "pytorch_lora_weights.bin", - "pytorch_lora_weights.safetensors", - } - ] + (path / x).exists() + for x in { + "config.json", + "model_index.json", + "learned_embeds.bin", + "pytorch_lora_weights.bin", + "pytorch_lora_weights.safetensors", + } ): models_installed.update({str(model_path_id_or_url): self._install_path(path)}) @@ -433,17 +431,17 @@ class ModelInstall(object): rel_path = self.relative_to_root(path, self.config.models_path) - attributes = dict( - path=str(rel_path), - description=str(description), - model_format=info.format, - ) + attributes = { + "path": str(rel_path), + "description": str(description), + "model_format": info.format, + } legacy_conf = None if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX: attributes.update( - dict( - variant=info.variant_type, - ) + { + "variant": info.variant_type, + } ) if info.format == "checkpoint": try: @@ -474,7 +472,7 @@ class ModelInstall(object): ) if legacy_conf: - attributes.update(dict(config=str(legacy_conf))) + attributes.update({"config": str(legacy_conf)}) return attributes def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path: @@ -519,7 +517,7 @@ class ModelInstall(object): def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path: _, name = repo_id.split("/") location = staging / name - paths = list() + paths = [] for filename in files: filePath = Path(filename) p = hf_download_with_resume( diff --git a/invokeai/backend/ip_adapter/attention_processor.py b/invokeai/backend/ip_adapter/attention_processor.py index c06d7d113c..195cb12d1b 100644 --- a/invokeai/backend/ip_adapter/attention_processor.py +++ b/invokeai/backend/ip_adapter/attention_processor.py @@ -130,7 +130,9 @@ class IPAttnProcessor2_0(torch.nn.Module): assert ip_adapter_image_prompt_embeds is not None assert len(ip_adapter_image_prompt_embeds) == len(self._weights) - for ipa_embed, ipa_weights, scale in 
zip(ip_adapter_image_prompt_embeds, self._weights, self._scales): + for ipa_embed, ipa_weights, scale in zip( + ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True + ): # The batch dimensions should match. assert ipa_embed.shape[0] == encoder_hidden_states.shape[0] # The token_len dimensions should match. diff --git a/invokeai/backend/ip_adapter/resampler.py b/invokeai/backend/ip_adapter/resampler.py index 84224fd359..a8db22c0fd 100644 --- a/invokeai/backend/ip_adapter/resampler.py +++ b/invokeai/backend/ip_adapter/resampler.py @@ -56,7 +56,7 @@ class PerceiverAttention(nn.Module): x = self.norm1(x) latents = self.norm2(latents) - b, l, _ = latents.shape + b, L, _ = latents.shape q = self.to_q(latents) kv_input = torch.cat((x, latents), dim=-2) @@ -72,7 +72,7 @@ class PerceiverAttention(nn.Module): weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) out = weight @ v - out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + out = out.permute(0, 2, 1, 3).reshape(b, L, -1) return self.to_out(out) diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index 0a3a63dad6..1cecfb1a72 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -269,7 +269,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa resolution *= 2 up_block_types = [] - for i in range(len(block_out_channels)): + for _i in range(len(block_out_channels)): block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" up_block_types.append(block_type) resolution //= 2 @@ -1223,7 +1223,7 @@ def download_from_original_stable_diffusion_ckpt( # scan model scan_result = scan_file_path(checkpoint_path) if scan_result.infected_files != 0: - raise "The model {checkpoint_path} is potentially infected by malware. Aborting import." + raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.") if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" checkpoint = torch.load(checkpoint_path, map_location=device) @@ -1664,7 +1664,7 @@ def download_controlnet_from_original_ckpt( # scan model scan_result = scan_file_path(checkpoint_path) if scan_result.infected_files != 0: - raise "The model {checkpoint_path} is potentially infected by malware. Aborting import." + raise Exception("The model {checkpoint_path} is potentially infected by malware. 
Aborting import.") if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" checkpoint = torch.load(checkpoint_path, map_location=device) diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index e4caf60aac..4389cacacc 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -104,7 +104,7 @@ class ModelPatcher: loras: List[Tuple[LoRAModel, float]], prefix: str, ): - original_weights = dict() + original_weights = {} try: with torch.no_grad(): for lora, lora_weight in loras: @@ -242,7 +242,7 @@ class ModelPatcher: ): skipped_layers = [] try: - for i in range(clip_skip): + for _i in range(clip_skip): skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1)) yield @@ -324,7 +324,7 @@ class TextualInversionManager(BaseTextualInversionManager): tokenizer: CLIPTokenizer def __init__(self, tokenizer: CLIPTokenizer): - self.pad_tokens = dict() + self.pad_tokens = {} self.tokenizer = tokenizer def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: @@ -385,10 +385,10 @@ class ONNXModelPatcher: if not isinstance(model, IAIOnnxRuntimeModel): raise Exception("Only IAIOnnxRuntimeModel models supported") - orig_weights = dict() + orig_weights = {} try: - blended_loras = dict() + blended_loras = {} for lora, lora_weight in loras: for layer_key, layer in lora.layers.items(): @@ -404,7 +404,7 @@ class ONNXModelPatcher: else: blended_loras[layer_key] = layer_weight - node_names = dict() + node_names = {} for node in model.nodes.values(): node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index 83af789219..2a7f4b5a95 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -66,11 +66,13 @@ class CacheStats(object): class ModelLocker(object): "Forward declaration" + pass class ModelCache(object): "Forward declaration" + pass @@ -132,7 +134,7 @@ class ModelCache(object): snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's behaviour. 
""" - self.model_infos: Dict[str, ModelBase] = dict() + self.model_infos: Dict[str, ModelBase] = {} # allow lazy offloading only when vram cache enabled self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0 self.precision: torch.dtype = precision @@ -147,8 +149,8 @@ class ModelCache(object): # used for stats collection self.stats = None - self._cached_models = dict() - self._cache_stack = list() + self._cached_models = {} + self._cache_stack = [] def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]: if self._log_memory_usage: diff --git a/invokeai/backend/model_management/model_load_optimizations.py b/invokeai/backend/model_management/model_load_optimizations.py index 8dc8a8793e..a46d262175 100644 --- a/invokeai/backend/model_management/model_load_optimizations.py +++ b/invokeai/backend/model_management/model_load_optimizations.py @@ -26,5 +26,5 @@ def skip_torch_weight_init(): yield None finally: - for torch_module, saved_function in zip(torch_modules, saved_functions): + for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True): torch_module.reset_parameters = saved_function diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index da4239fa07..e9f498a438 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -363,7 +363,7 @@ class ModelManager(object): else: return - self.models = dict() + self.models = {} for model_key, model_config in config.items(): if model_key.startswith("_"): continue @@ -374,7 +374,7 @@ class ModelManager(object): self.models[model_key] = model_class.create_config(**model_config) # check config version number and update on disk/RAM if necessary - self.cache_keys = dict() + self.cache_keys = {} # add controlnet, lora and textual_inversion models from disk self.scan_models_directory() @@ -655,7 +655,7 @@ class ModelManager(object): """ # TODO: redo for model_dict in self.list_models(): - for model_name, model_info in model_dict.items(): + for _model_name, model_info in model_dict.items(): line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}' print(line) @@ -902,7 +902,7 @@ class ModelManager(object): """ Write current configuration out to the indicated file. 
""" - data_to_save = dict() + data_to_save = {} data_to_save["__metadata__"] = self.config_meta.model_dump() for model_key, model_config in self.models.items(): @@ -1034,7 +1034,7 @@ class ModelManager(object): self.ignore = ignore def on_search_started(self): - self.new_models_found = dict() + self.new_models_found = {} def on_model_found(self, model: Path): if model not in self.ignore: @@ -1106,7 +1106,7 @@ class ModelManager(object): # avoid circular import here from invokeai.backend.install.model_install_backend import ModelInstall - successfully_installed = dict() + successfully_installed = {} installer = ModelInstall( config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self diff --git a/invokeai/backend/model_management/model_merge.py b/invokeai/backend/model_management/model_merge.py index 59201d64d9..a9f0a23618 100644 --- a/invokeai/backend/model_management/model_merge.py +++ b/invokeai/backend/model_management/model_merge.py @@ -92,7 +92,7 @@ class ModelMerger(object): **kwargs - the default DiffusionPipeline.get_config_dict kwargs: cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map """ - model_paths = list() + model_paths = [] config = self.manager.app_config base_model = BaseModelType(base_model) vae = None @@ -124,13 +124,13 @@ class ModelMerger(object): dump_path = (dump_path / merged_model_name).as_posix() merged_pipe.save_pretrained(dump_path, safe_serialization=True) - attributes = dict( - path=dump_path, - description=f"Merge of models {', '.join(model_names)}", - model_format="diffusers", - variant=ModelVariantType.Normal.value, - vae=vae, - ) + attributes = { + "path": dump_path, + "description": f"Merge of models {', '.join(model_names)}", + "model_format": "diffusers", + "variant": ModelVariantType.Normal.value, + "vae": vae, + } return self.manager.add_model( merged_model_name, base_model=base_model, diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index e82be8d069..83d3d610c3 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -237,7 +237,7 @@ class ModelProbe(object): # scan model scan_result = scan_file_path(checkpoint) if scan_result.infected_files != 0: - raise "The model {model_name} is potentially infected by malware. Aborting import." + raise Exception("The model {model_name} is potentially infected by malware. 
Aborting import.") # ##################################################3 diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management/model_search.py index 7e6b37c832..e125c3ced7 100644 --- a/invokeai/backend/model_management/model_search.py +++ b/invokeai/backend/model_management/model_search.py @@ -59,7 +59,7 @@ class ModelSearch(ABC): for root, dirs, files in os.walk(path, followlinks=True): if str(Path(root).name).startswith("."): self._pruned_paths.add(root) - if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): + if any(Path(root).is_relative_to(x) for x in self._pruned_paths): continue self._items_scanned += len(dirs) + len(files) @@ -69,16 +69,14 @@ class ModelSearch(ABC): self._scanned_dirs.add(path) continue if any( - [ - (path / x).exists() - for x in { - "config.json", - "model_index.json", - "learned_embeds.bin", - "pytorch_lora_weights.bin", - "image_encoder.txt", - } - ] + (path / x).exists() + for x in { + "config.json", + "model_index.json", + "learned_embeds.bin", + "pytorch_lora_weights.bin", + "image_encoder.txt", + } ): try: self.on_model_found(path) diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 0afd731032..5f9b13b96f 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -97,8 +97,8 @@ MODEL_CLASSES = { # }, } -MODEL_CONFIGS = list() -OPENAPI_MODEL_CONFIGS = list() +MODEL_CONFIGS = [] +OPENAPI_MODEL_CONFIGS = [] class OpenAPIModelInfoBase(BaseModel): @@ -109,7 +109,7 @@ class OpenAPIModelInfoBase(BaseModel): model_config = ConfigDict(protected_namespaces=()) -for base_model, models in MODEL_CLASSES.items(): +for _base_model, models in MODEL_CLASSES.items(): for model_type, model_class in models.items(): model_configs = set(model_class._get_configs().values()) model_configs.discard(None) @@ -133,7 +133,7 @@ for base_model, models in MODEL_CLASSES.items(): def get_model_config_enums(): - enums = list() + enums = [] for model_config in MODEL_CONFIGS: if hasattr(inspect, "get_annotations"): diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index f735e37189..7807cb9a54 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -153,7 +153,7 @@ class ModelBase(metaclass=ABCMeta): else: res_type = sys.modules["diffusers"] - res_type = getattr(res_type, "pipelines") + res_type = res_type.pipelines for subtype in subtypes: res_type = getattr(res_type, subtype) @@ -164,7 +164,7 @@ class ModelBase(metaclass=ABCMeta): with suppress(Exception): return cls.__configs - configs = dict() + configs = {} for name in dir(cls): if name.startswith("__"): continue @@ -246,8 +246,8 @@ class DiffusersModel(ModelBase): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): super().__init__(model_path, base_model, model_type) - self.child_types: Dict[str, Type] = dict() - self.child_sizes: Dict[str, int] = dict() + self.child_types: Dict[str, Type] = {} + self.child_sizes: Dict[str, int] = {} try: config_data = DiffusionPipeline.load_config(self.model_path) @@ -326,8 +326,8 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari all_files = os.listdir(model_path) all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))] - fp16_files = set([f for f in all_files if ".fp16." 
in f or ".fp16-" in f]) - bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f]) + fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f} + bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f} other_files = set(all_files) - fp16_files - bit8_files if variant is None: @@ -413,7 +413,7 @@ def _calc_onnx_model_by_data(model) -> int: def _fast_safetensors_reader(path: str): - checkpoint = dict() + checkpoint = {} device = torch.device("meta") with open(path, "rb") as f: definition_len = int.from_bytes(f.read(8), "little") @@ -483,7 +483,7 @@ class IAIOnnxRuntimeModel: class _tensor_access: def __init__(self, model): self.model = model - self.indexes = dict() + self.indexes = {} for idx, obj in enumerate(self.model.proto.graph.initializer): self.indexes[obj.name] = idx @@ -524,7 +524,7 @@ class IAIOnnxRuntimeModel: class _access_helper: def __init__(self, raw_proto): - self.indexes = dict() + self.indexes = {} self.raw_proto = raw_proto for idx, obj in enumerate(raw_proto): self.indexes[obj.name] = idx @@ -549,7 +549,7 @@ class IAIOnnxRuntimeModel: return self.indexes.keys() def values(self): - return [obj for obj in self.raw_proto] + return list(self.raw_proto) def __init__(self, model_path: str, provider: Optional[str]): self.path = model_path diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management/models/controlnet.py index 6a42b59fe1..da269eba4b 100644 --- a/invokeai/backend/model_management/models/controlnet.py +++ b/invokeai/backend/model_management/models/controlnet.py @@ -104,7 +104,7 @@ class ControlNetModel(ModelBase): return ControlNetModelFormat.Diffusers if os.path.isfile(path): - if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]): + if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]): return ControlNetModelFormat.Checkpoint raise InvalidModelException(f"Not a valid model: {path}") diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 43a24275d1..b110d75d22 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -73,7 +73,7 @@ class LoRAModel(ModelBase): return LoRAModelFormat.Diffusers if os.path.isfile(path): - if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]): + if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): return LoRAModelFormat.LyCORIS raise InvalidModelException(f"Not a valid model: {path}") @@ -462,7 +462,7 @@ class LoRAModelRaw: # (torch.nn.Module): dtype: Optional[torch.dtype] = None, ): # TODO: try revert if exception? 
- for key, layer in self.layers.items(): + for _key, layer in self.layers.items(): layer.to(device=device, dtype=dtype) def calc_size(self) -> int: @@ -499,7 +499,7 @@ class LoRAModelRaw: # (torch.nn.Module): stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP) stability_unet_keys.sort() - new_state_dict = dict() + new_state_dict = {} for full_key, value in state_dict.items(): if full_key.startswith("lora_unet_"): search_key = full_key.replace("lora_unet_", "") @@ -545,7 +545,7 @@ class LoRAModelRaw: # (torch.nn.Module): model = cls( name=file_path.stem, # TODO: - layers=dict(), + layers={}, ) if file_path.suffix == ".safetensors": @@ -593,12 +593,12 @@ class LoRAModelRaw: # (torch.nn.Module): @staticmethod def _group_state(state_dict: dict): - state_dict_groupped = dict() + state_dict_groupped = {} for key, value in state_dict.items(): stem, leaf = key.split(".", 1) if stem not in state_dict_groupped: - state_dict_groupped[stem] = dict() + state_dict_groupped[stem] = {} state_dict_groupped[stem][leaf] = value return state_dict_groupped diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index ffce42d9e9..a38a44fccf 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -110,7 +110,7 @@ class StableDiffusion1Model(DiffusersModel): return StableDiffusion1ModelFormat.Diffusers if os.path.isfile(model_path): - if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]): + if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): return StableDiffusion1ModelFormat.Checkpoint raise InvalidModelException(f"Not a valid model: {model_path}") @@ -221,7 +221,7 @@ class StableDiffusion2Model(DiffusersModel): return StableDiffusion2ModelFormat.Diffusers if os.path.isfile(model_path): - if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]): + if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): return StableDiffusion2ModelFormat.Checkpoint raise InvalidModelException(f"Not a valid model: {model_path}") diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_management/models/textual_inversion.py index b59e635045..99358704b8 100644 --- a/invokeai/backend/model_management/models/textual_inversion.py +++ b/invokeai/backend/model_management/models/textual_inversion.py @@ -71,7 +71,7 @@ class TextualInversionModel(ModelBase): return None # diffusers-ti if os.path.isfile(path): - if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]): + if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]): return None raise InvalidModelException(f"Not a valid model: {path}") diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py index 637160c69b..8cc37e67a7 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_management/models/vae.py @@ -89,7 +89,7 @@ class VaeModel(ModelBase): return VaeModelFormat.Diffusers if os.path.isfile(path): - if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]): + if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): return VaeModelFormat.Checkpoint raise InvalidModelException(f"Not a valid model: {path}") diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py 
b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 1b65326f6e..1353e804a7 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -193,6 +193,7 @@ class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput): attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user after generation completes. Optional. """ + attention_map_saver: Optional[AttentionMapSaver] diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py index 3cb1862004..92a538ff70 100644 --- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py +++ b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py @@ -54,13 +54,13 @@ class Context: self.clear_requests(cleanup=True) def register_cross_attention_modules(self, model): - for name, module in get_cross_attention_modules(model, CrossAttentionType.SELF): + for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF): if name in self.self_cross_attention_module_identifiers: - assert False, f"name {name} cannot appear more than once" + raise AssertionError(f"name {name} cannot appear more than once") self.self_cross_attention_module_identifiers.append(name) - for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS): + for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS): if name in self.tokens_cross_attention_module_identifiers: - assert False, f"name {name} cannot appear more than once" + raise AssertionError(f"name {name} cannot appear more than once") self.tokens_cross_attention_module_identifiers.append(name) def request_save_attention_maps(self, cross_attention_type: CrossAttentionType): @@ -170,7 +170,7 @@ class Context: self.saved_cross_attention_maps = {} def offload_saved_attention_slices_to_cpu(self): - for key, map_dict in self.saved_cross_attention_maps.items(): + for _key, map_dict in self.saved_cross_attention_maps.items(): for offset, slice in map_dict["slices"].items(): map_dict[offset] = slice.to("cpu") @@ -433,7 +433,7 @@ def inject_attention_function(unet, context: Context): module.identifier = identifier try: module.set_attention_slice_wrangler(attention_slice_wrangler) - module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) + module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023 except AttributeError as e: if is_attribute_error_about(e, "set_attention_slice_wrangler"): print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO @@ -445,7 +445,7 @@ def remove_attention_function(unet): cross_attention_modules = get_cross_attention_modules( unet, CrossAttentionType.TOKENS ) + get_cross_attention_modules(unet, CrossAttentionType.SELF) - for identifier, module in cross_attention_modules: + for _identifier, module in cross_attention_modules: try: # clear wrangler callback module.set_attention_slice_wrangler(None) diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_map_saving.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_map_saving.py index b5ea40185a..82c9f1dcea 100644 --- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_map_saving.py +++ b/invokeai/backend/stable_diffusion/diffusion/cross_attention_map_saving.py @@ -56,7 +56,7 @@ class 
AttentionMapSaver: merged = None - for key, maps in self.collated_maps.items(): + for _key, maps in self.collated_maps.items(): # maps has shape [(H*W), N] for N tokens # but we want [N, H, W] this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height)) diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index b7c0058fe9..455e5e1096 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -123,7 +123,7 @@ class InvokeAIDiffuserComponent: # control_data should be type List[ControlNetData] # this loop covers both ControlNet (one ControlNetData in list) # and MultiControlNet (multiple ControlNetData in list) - for i, control_datum in enumerate(control_data): + for _i, control_datum in enumerate(control_data): control_mode = control_datum.control_mode # soft_injection and cfg_injection are the two ControlNet control_mode booleans # that are combined at higher level to make control_mode enum @@ -214,7 +214,7 @@ class InvokeAIDiffuserComponent: # add controlnet outputs together if have multiple controlnets down_block_res_samples = [ samples_prev + samples_curr - for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples, strict=True) ] mid_block_res_sample += mid_sample @@ -642,7 +642,9 @@ class InvokeAIDiffuserComponent: deltas = None uncond_latents = None - weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)] + weighted_cond_list = ( + c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)] + ) # below is fugly omg conditionings = [uc] + [c for c, weight in weighted_cond_list] diff --git a/invokeai/backend/stable_diffusion/schedulers/schedulers.py b/invokeai/backend/stable_diffusion/schedulers/schedulers.py index 65bb05b582..c824d94dca 100644 --- a/invokeai/backend/stable_diffusion/schedulers/schedulers.py +++ b/invokeai/backend/stable_diffusion/schedulers/schedulers.py @@ -16,28 +16,28 @@ from diffusers import ( UniPCMultistepScheduler, ) -SCHEDULER_MAP = dict( - ddim=(DDIMScheduler, dict()), - ddpm=(DDPMScheduler, dict()), - deis=(DEISMultistepScheduler, dict()), - lms=(LMSDiscreteScheduler, dict(use_karras_sigmas=False)), - lms_k=(LMSDiscreteScheduler, dict(use_karras_sigmas=True)), - pndm=(PNDMScheduler, dict()), - heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)), - heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)), - euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)), - euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)), - euler_a=(EulerAncestralDiscreteScheduler, dict()), - kdpm_2=(KDPM2DiscreteScheduler, dict()), - kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()), - dpmpp_2s=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=False)), - dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)), - dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)), - dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)), - dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type="sde-dpmsolver++")), - dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")), - dpmpp_sde=(DPMSolverSDEScheduler, 
dict(use_karras_sigmas=False, noise_sampler_seed=0)), - dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)), - unipc=(UniPCMultistepScheduler, dict(cpu_only=True)), - lcm=(LCMScheduler, dict()), -) +SCHEDULER_MAP = { + "ddim": (DDIMScheduler, {}), + "ddpm": (DDPMScheduler, {}), + "deis": (DEISMultistepScheduler, {}), + "lms": (LMSDiscreteScheduler, {"use_karras_sigmas": False}), + "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}), + "pndm": (PNDMScheduler, {}), + "heun": (HeunDiscreteScheduler, {"use_karras_sigmas": False}), + "heun_k": (HeunDiscreteScheduler, {"use_karras_sigmas": True}), + "euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}), + "euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}), + "euler_a": (EulerAncestralDiscreteScheduler, {}), + "kdpm_2": (KDPM2DiscreteScheduler, {}), + "kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}), + "dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}), + "dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}), + "dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False}), + "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}), + "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}), + "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}), + "dpmpp_sde": (DPMSolverSDEScheduler, {"use_karras_sigmas": False, "noise_sampler_seed": 0}), + "dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}), + "unipc": (UniPCMultistepScheduler, {"cpu_only": True}), + "lcm": (LCMScheduler, {}), +} diff --git a/invokeai/backend/training/textual_inversion_training.py b/invokeai/backend/training/textual_inversion_training.py index 9bc1d188bc..0d1b9a2d8c 100644 --- a/invokeai/backend/training/textual_inversion_training.py +++ b/invokeai/backend/training/textual_inversion_training.py @@ -615,7 +615,7 @@ def do_textual_inversion_training( vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae) unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet) - pipeline_args = dict(local_files_only=True) + pipeline_args = {"local_files_only": True} if tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args) else: diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index fb1297996c..835575c7a1 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -732,7 +732,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): controlnet_down_block_res_samples = () - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + for down_block_res_sample, controlnet_block in zip( + down_block_res_samples, self.controlnet_down_blocks, strict=True + ): down_block_res_sample = controlnet_block(down_block_res_sample) controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) @@ -745,7 +747,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 scales = scales * conditioning_scale - down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + down_block_res_samples = [ + sample * scale for sample, 
scale in zip(down_block_res_samples, scales, strict=True) + ] mid_block_res_sample = mid_block_res_sample * scales[-1] # last one else: down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] diff --git a/invokeai/backend/util/logging.py b/invokeai/backend/util/logging.py index 3c829a1a02..cb9f362d90 100644 --- a/invokeai/backend/util/logging.py +++ b/invokeai/backend/util/logging.py @@ -225,34 +225,34 @@ def basicConfig(**kwargs): _FACILITY_MAP = ( - dict( - LOG_KERN=syslog.LOG_KERN, - LOG_USER=syslog.LOG_USER, - LOG_MAIL=syslog.LOG_MAIL, - LOG_DAEMON=syslog.LOG_DAEMON, - LOG_AUTH=syslog.LOG_AUTH, - LOG_LPR=syslog.LOG_LPR, - LOG_NEWS=syslog.LOG_NEWS, - LOG_UUCP=syslog.LOG_UUCP, - LOG_CRON=syslog.LOG_CRON, - LOG_SYSLOG=syslog.LOG_SYSLOG, - LOG_LOCAL0=syslog.LOG_LOCAL0, - LOG_LOCAL1=syslog.LOG_LOCAL1, - LOG_LOCAL2=syslog.LOG_LOCAL2, - LOG_LOCAL3=syslog.LOG_LOCAL3, - LOG_LOCAL4=syslog.LOG_LOCAL4, - LOG_LOCAL5=syslog.LOG_LOCAL5, - LOG_LOCAL6=syslog.LOG_LOCAL6, - LOG_LOCAL7=syslog.LOG_LOCAL7, - ) + { + "LOG_KERN": syslog.LOG_KERN, + "LOG_USER": syslog.LOG_USER, + "LOG_MAIL": syslog.LOG_MAIL, + "LOG_DAEMON": syslog.LOG_DAEMON, + "LOG_AUTH": syslog.LOG_AUTH, + "LOG_LPR": syslog.LOG_LPR, + "LOG_NEWS": syslog.LOG_NEWS, + "LOG_UUCP": syslog.LOG_UUCP, + "LOG_CRON": syslog.LOG_CRON, + "LOG_SYSLOG": syslog.LOG_SYSLOG, + "LOG_LOCAL0": syslog.LOG_LOCAL0, + "LOG_LOCAL1": syslog.LOG_LOCAL1, + "LOG_LOCAL2": syslog.LOG_LOCAL2, + "LOG_LOCAL3": syslog.LOG_LOCAL3, + "LOG_LOCAL4": syslog.LOG_LOCAL4, + "LOG_LOCAL5": syslog.LOG_LOCAL5, + "LOG_LOCAL6": syslog.LOG_LOCAL6, + "LOG_LOCAL7": syslog.LOG_LOCAL7, + } if SYSLOG_AVAILABLE - else dict() + else {} ) -_SOCK_MAP = dict( - SOCK_STREAM=socket.SOCK_STREAM, - SOCK_DGRAM=socket.SOCK_DGRAM, -) +_SOCK_MAP = { + "SOCK_STREAM": socket.SOCK_STREAM, + "SOCK_DGRAM": socket.SOCK_DGRAM, +} class InvokeAIFormatter(logging.Formatter): @@ -344,7 +344,7 @@ LOG_FORMATTERS = { class InvokeAILogger(object): - loggers = dict() + loggers = {} @classmethod def get_logger( @@ -364,7 +364,7 @@ class InvokeAILogger(object): @classmethod def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]: handler_strs = config.log_handlers - handlers = list() + handlers = [] for handler in handler_strs: handler_name, *args = handler.split("=", 2) args = args[0] if len(args) > 0 else None @@ -398,7 +398,7 @@ class InvokeAILogger(object): raise ValueError("syslog is not available on this system") if not args: args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514" - syslog_args = dict() + syslog_args = {} try: for a in args.split(","): arg_name, *arg_value = a.split(":", 2) @@ -434,7 +434,7 @@ class InvokeAILogger(object): path = url.path port = url.port or 80 - syslog_args = dict() + syslog_args = {} for a in arg_list: arg_name, *arg_value = a.split(":", 2) if arg_name == "method": diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 0796f1a8cd..12ba3701cf 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -26,7 +26,7 @@ def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) # xc a list of captions to plot b = len(xc) - txts = list() + txts = [] for bi in range(b): txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) @@ -90,7 +90,7 @@ def instantiate_from_config(config, **kwargs): elif config == "__is_unconditional__": return None raise KeyError("Expected key `target` to instantiate.") - return 
get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs) + return get_obj_from_str(config["target"])(**config.get("params", {}), **kwargs) def get_obj_from_str(string, reload=False): @@ -228,11 +228,12 @@ def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 angles = 2 * math.pi * rand_val gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device) - tile_grads = ( - lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] - .repeat_interleave(d[0], 0) - .repeat_interleave(d[1], 1) - ) + def tile_grads(slice1, slice2): + return ( + gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] + .repeat_interleave(d[0], 0) + .repeat_interleave(d[1], 1) + ) def dot(grad, shift): return ( diff --git a/invokeai/frontend/install/import_images.py b/invokeai/frontend/install/import_images.py index ec90700bd0..61faa48f9d 100644 --- a/invokeai/frontend/install/import_images.py +++ b/invokeai/frontend/install/import_images.py @@ -341,19 +341,19 @@ class InvokeAIMetadataParser: # this was more elegant as a case statement, but that's not available in python 3.9 if old_scheduler is None: return None - scheduler_map = dict( - ddim="ddim", - plms="pnmd", - k_lms="lms", - k_dpm_2="kdpm_2", - k_dpm_2_a="kdpm_2_a", - dpmpp_2="dpmpp_2s", - k_dpmpp_2="dpmpp_2m", - k_dpmpp_2_a=None, # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session - k_euler="euler", - k_euler_a="euler_a", - k_heun="heun", - ) + scheduler_map = { + "ddim": "ddim", + "plms": "pnmd", + "k_lms": "lms", + "k_dpm_2": "kdpm_2", + "k_dpm_2_a": "kdpm_2_a", + "dpmpp_2": "dpmpp_2s", + "k_dpmpp_2": "dpmpp_2m", + "k_dpmpp_2_a": None, # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session + "k_euler": "euler", + "k_euler_a": "euler_a", + "k_heun": "heun", + } return scheduler_map.get(old_scheduler) def split_prompt(self, raw_prompt: str): diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index 0ea2570b2b..e23538ffd6 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -72,7 +72,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): def __init__(self, parentApp, name, multipage=False, *args, **keywords): self.multipage = multipage self.subprocess = None - super().__init__(parentApp=parentApp, name=name, *args, **keywords) + super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad? 
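For readers skimming this hunk: the `# noqa: B026` added above silences flake8-bugbear's warning about unpacking `*args` after keyword arguments in the `super().__init__(...)` call, which is also why the author left the adjacent "maybe this is bad?" TODO. A minimal sketch of the surprise that rule guards against, using throwaway names rather than anything from this codebase:

    def f(a, b, c):
        return (a, b, c)

    extra = (1, 2)
    # Although *extra appears after the keyword, the unpacked values are still bound
    # positionally first, so this returns (1, 2, 9) rather than raising.
    print(f(c=9, *extra))
    # If an unpacked positional overlaps a keyword, the call fails at runtime:
    # f(a=9, *extra)  ->  TypeError: f() got multiple values for argument 'a'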
def create(self): self.keypress_timeout = 10 @@ -203,14 +203,14 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): ) # This restores the selected page on return from an installation - for i in range(1, self.current_tab + 1): + for _i in range(1, self.current_tab + 1): self.tabs.h_cursor_line_down(1) self._toggle_tables([self.current_tab]) ############# diffusers tab ########## def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: """Add widgets responsible for selecting diffusers models""" - widgets = dict() + widgets = {} models = self.all_models starters = self.starter_models starter_model_labels = self.model_labels @@ -258,10 +258,12 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): model_type: ModelType, window_width: int = 120, install_prompt: str = None, - exclude: set = set(), + exclude: set = None, ) -> dict[str, npyscreen.widget]: """Generic code to create model selection widgets""" - widgets = dict() + if exclude is None: + exclude = set() + widgets = {} model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude] model_labels = [self.model_labels[x] for x in model_list] @@ -366,13 +368,13 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): ] for group in widgets: - for k, v in group.items(): + for _k, v in group.items(): try: v.hidden = True v.editable = False except Exception: pass - for k, v in widgets[selected_tab].items(): + for _k, v in widgets[selected_tab].items(): try: v.hidden = False if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)): @@ -391,7 +393,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): label_width = max([len(models[x].name) for x in models]) description_width = window_width - label_width - checkbox_width - spacing_width - result = dict() + result = {} for x in models.keys(): description = models[x].description description = ( @@ -433,11 +435,11 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): parent_conn, child_conn = Pipe() p = Process( target=process_and_execute, - kwargs=dict( - opt=app.program_opts, - selections=app.install_selections, - conn_out=child_conn, - ), + kwargs={ + "opt": app.program_opts, + "selections": app.install_selections, + "conn_out": child_conn, + }, ) p.start() child_conn.close() @@ -558,7 +560,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): for section in ui_sections: if "models_selected" not in section: continue - selected = set([section["models"][x] for x in section["models_selected"].value]) + selected = {section["models"][x] for x in section["models_selected"].value} models_to_install = [x for x in selected if not self.all_models[x].installed] models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed] selections.remove_models.extend(models_to_remove) diff --git a/invokeai/frontend/install/widgets.py b/invokeai/frontend/install/widgets.py index 4a37aba9b8..5905ae29da 100644 --- a/invokeai/frontend/install/widgets.py +++ b/invokeai/frontend/install/widgets.py @@ -11,6 +11,7 @@ import sys import textwrap from curses import BUTTON2_CLICKED, BUTTON3_CLICKED from shutil import get_terminal_size +from typing import Optional import npyscreen import npyscreen.wgmultiline as wgmultiline @@ -243,7 +244,9 @@ class SelectColumnBase: class MultiSelectColumns(SelectColumnBase, npyscreen.MultiSelect): - def __init__(self, screen, columns: int = 1, values: list = [], **keywords): + def __init__(self, screen, columns: int = 1, 
values: Optional[list] = None, **keywords): + if values is None: + values = [] self.columns = columns self.value_cnt = len(values) self.rows = math.ceil(self.value_cnt / self.columns) @@ -267,7 +270,9 @@ class SingleSelectWithChanged(npyscreen.SelectOne): class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged): """Row of radio buttons. Spacebar to select.""" - def __init__(self, screen, columns: int = 1, values: list = [], **keywords): + def __init__(self, screen, columns: int = 1, values: list = None, **keywords): + if values is None: + values = [] self.columns = columns self.value_cnt = len(values) self.rows = math.ceil(self.value_cnt / self.columns) diff --git a/invokeai/frontend/merge/merge_diffusers.py b/invokeai/frontend/merge/merge_diffusers.py index d515c5b4ee..92b98b52f9 100644 --- a/invokeai/frontend/merge/merge_diffusers.py +++ b/invokeai/frontend/merge/merge_diffusers.py @@ -275,14 +275,14 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): interp = self.interpolations[self.merge_method.value[0]] bases = ["sd-1", "sd-2", "sdxl"] - args = dict( - model_names=models, - base_model=BaseModelType(bases[self.base_select.value[0]]), - alpha=self.alpha.value, - interp=interp, - force=self.force.value, - merged_model_name=self.merged_model_name.value, - ) + args = { + "model_names": models, + "base_model": BaseModelType(bases[self.base_select.value[0]]), + "alpha": self.alpha.value, + "interp": interp, + "force": self.force.value, + "merged_model_name": self.merged_model_name.value, + } return args def check_for_overwrite(self) -> bool: @@ -297,7 +297,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): def validate_field_values(self) -> bool: bad_fields = [] model_names = self.model_names - selected_models = set((model_names[self.model1.value[0]], model_names[self.model2.value[0]])) + selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]} if self.model3.value[0] > 0: selected_models.add(model_names[self.model3.value[0] - 1]) if len(selected_models) < 2: diff --git a/invokeai/frontend/training/textual_inversion.py b/invokeai/frontend/training/textual_inversion.py index f3911f7e0e..556f216e97 100755 --- a/invokeai/frontend/training/textual_inversion.py +++ b/invokeai/frontend/training/textual_inversion.py @@ -276,13 +276,13 @@ class textualInversionForm(npyscreen.FormMultiPageAction): def get_model_names(self) -> Tuple[List[str], int]: conf = OmegaConf.load(config.root_dir / "configs/models.yaml") - model_names = [idx for idx in sorted(list(conf.keys())) if conf[idx].get("format", None) == "diffusers"] + model_names = [idx for idx in sorted(conf.keys()) if conf[idx].get("format", None) == "diffusers"] defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]] default = defaults[0] if len(defaults) > 0 else 0 return (model_names, default) def marshall_arguments(self) -> dict: - args = dict() + args = {} # the choices args.update( diff --git a/pyproject.toml b/pyproject.toml index fc9c150e3e..bdf23c4bf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,10 +35,10 @@ dependencies = [ "accelerate~=0.23.0", "albumentations", "click", - "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", + "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "compel~=2.0.2", "controlnet-aux>=0.0.6", - "timm==0.6.13", # needed to override timm latest in controlnet_aux, see 
https://github.com/isl-org/ZoeDepth/issues/26 + "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "datasets", "diffusers[torch]~=0.23.0", "dnspython~=2.4.0", @@ -96,10 +96,8 @@ dependencies = [ ] "dev" = ["jurigged", "pudb"] "test" = [ - "black", - "flake8", - "Flake8-pyproject", - "isort", + "ruff", + "ruff-lsp", "mypy", "pre-commit", "pytest>6.0.0", @@ -108,7 +106,7 @@ dependencies = [ ] "xformers" = [ "xformers==0.0.22post7; sys_platform!='darwin'", - "triton; sys_platform=='linux'", + "triton; sys_platform=='linux'", ] "onnx" = ["onnxruntime"] "onnx-cuda" = ["onnxruntime-gpu"] @@ -194,10 +192,16 @@ directory = "coverage/html" output = "coverage/index.xml" #=== End: PyTest and Coverage -[tool.flake8] -max-line-length = 120 -ignore = ["E203", "E266", "E501", "W503"] -select = ["B", "C", "E", "F", "W", "T4"] +#=== Begin: Ruff +[tool.ruff] +line-length = 120 +ignore = [ + "E501", # https://docs.astral.sh/ruff/rules/line-too-long/ + "C901", # https://docs.astral.sh/ruff/rules/complex-structure/ + "B008", # https://docs.astral.sh/ruff/rules/function-call-in-default-argument/ + "B904", # https://docs.astral.sh/ruff/rules/raise-without-from-inside-except/ +] +select = ["B", "C", "E", "F", "W"] exclude = [ ".git", "__pycache__", @@ -206,14 +210,9 @@ exclude = [ "invokeai/frontend/web/node_modules/", ".venv*", ] +#=== End: Ruff -[tool.black] -line-length = 120 - -[tool.isort] -profile = "black" -line_length = 120 - +#=== Begin: MyPy [tool.mypy] ignore_missing_imports = true # ignores missing types in third-party libraries @@ -263,3 +262,4 @@ module = [ "invokeai.backend.util.util", "invokeai.frontend.install.model_install", ] +#=== End: MyPy diff --git a/scripts/configure_invokeai.py b/scripts/configure_invokeai.py index 98e5044511..c994668ea6 100755 --- a/scripts/configure_invokeai.py +++ b/scripts/configure_invokeai.py @@ -6,5 +6,7 @@ import warnings from invokeai.frontend.install.invokeai_configure import invokeai_configure as configure if __name__ == "__main__": - warnings.warn("configure_invokeai.py is deprecated, running 'invokeai-configure'...", DeprecationWarning) + warnings.warn( + "configure_invokeai.py is deprecated, running 'invokeai-configure'...", DeprecationWarning, stacklevel=2 + ) configure() diff --git a/scripts/scan_models_directory.py b/scripts/scan_models_directory.py index 0038023c06..a85fb793dd 100755 --- a/scripts/scan_models_directory.py +++ b/scripts/scan_models_directory.py @@ -37,22 +37,22 @@ def main(): if args.all_models or model_type == "diffusers": for d in dirs: - conf[f"{base}/{model_type}/{d}"] = dict( - path=os.path.join(root, d), - description=f"{model_type} model {d}", - format="folder", - base=base, - ) + conf[f"{base}/{model_type}/{d}"] = { + "path": os.path.join(root, d), + "description": f"{model_type} model {d}", + "format": "folder", + "base": base, + } for f in files: basename = Path(f).stem format = Path(f).suffix[1:] - conf[f"{base}/{model_type}/{basename}"] = dict( - path=os.path.join(root, f), - description=f"{model_type} model {basename}", - format=format, - base=base, - ) + conf[f"{base}/{model_type}/{basename}"] = { + "path": os.path.join(root, f), + "description": f"{model_type} model {basename}", + "format": format, + "base": base, + } OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout) diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 171cdfdb6f..f518460612 100644 --- 
a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -149,8 +149,8 @@ def test_graph_state_expands_iterator(mock_services): invoke_next(g, mock_services) prepared_add_nodes = g.source_prepared_mapping["3"] - results = set([g.results[n].value for n in prepared_add_nodes]) - expected = set([1, 11, 21]) + results = {g.results[n].value for n in prepared_add_nodes} + expected = {1, 11, 21} assert results == expected @@ -229,7 +229,7 @@ def test_graph_executes_depth_first(mock_services): # Because ordering is not guaranteed, we cannot compare results directly. # Instead, we must count the number of results. def get_completed_count(g, id): - ids = [i for i in g.source_prepared_mapping[id]] + ids = list(g.source_prepared_mapping[id]) completed_ids = [i for i in g.executed if i in ids] return len(completed_ids) diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py index e2a50e61e5..12a181f392 100644 --- a/tests/nodes/test_node_graph.py +++ b/tests/nodes/test_node_graph.py @@ -471,7 +471,6 @@ def test_graph_gets_subgraph_node(): g = Graph() n1 = GraphInvocation(id="1") n1.graph = Graph() - n1.graph.add_node n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") n1.graph.add_node(n1_1) @@ -503,8 +502,8 @@ def test_graph_expands_subgraph(): g.add_edge(create_edge("1.2", "value", "2", "a")) dg = g.nx_graph_flat() - assert set(dg.nodes) == set(["1.1", "1.2", "2"]) - assert set(dg.edges) == set([("1.1", "1.2"), ("1.2", "2")]) + assert set(dg.nodes) == {"1.1", "1.2", "2"} + assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")} def test_graph_subgraph_t2i(): @@ -532,9 +531,7 @@ def test_graph_subgraph_t2i(): # Validate dg = g.nx_graph_flat() - assert set(dg.nodes) == set( - ["1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"] - ) + assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"} expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges] expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")]) print(expected_edges) @@ -546,7 +543,6 @@ def test_graph_fails_to_get_missing_subgraph_node(): g = Graph() n1 = GraphInvocation(id="1") n1.graph = Graph() - n1.graph.add_node n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") n1.graph.add_node(n1_1) @@ -561,7 +557,6 @@ def test_graph_fails_to_enumerate_non_subgraph_node(): g = Graph() n1 = GraphInvocation(id="1") n1.graph = Graph() - n1.graph.add_node n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") n1.graph.add_node(n1_1) diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py index 51b33dd4c7..caafa33591 100644 --- a/tests/nodes/test_nodes.py +++ b/tests/nodes/test_nodes.py @@ -130,7 +130,7 @@ class TestEventService(EventServiceBase): def __init__(self): super().__init__() - self.events = list() + self.events = [] def dispatch(self, event_name: str, payload: Any) -> None: pass diff --git a/tests/nodes/test_session_queue.py b/tests/nodes/test_session_queue.py index cdab5729f8..768b09898d 100644 --- a/tests/nodes/test_session_queue.py +++ b/tests/nodes/test_session_queue.py @@ -169,7 +169,7 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph): NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue]) # should have 3 node field values - assert type(values[0].field_values) is str + assert isinstance(values[0].field_values, str) assert 
len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3 # should have batch id and priority
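One recurring functional change in this diff is threading `strict=True` through `zip(...)` calls (attention_processor.py, model_load_optimizations.py, hotfixes.py, shared_invokeai_diffusion.py). As a minimal sketch, independent of the InvokeAI code, of what that flag changes on Python 3.10 and later:

    a = [1, 2, 3]
    b = ["x", "y"]

    print(list(zip(a, b)))              # [(1, 'x'), (2, 'y')]; the extra element of `a` is silently dropped
    try:
        list(zip(a, b, strict=True))    # strict= requires Python 3.10+
    except ValueError as err:
        print(err)                      # reports that argument 2 is shorter than argument 1

So where the zipped sequences are expected to stay the same length, the updated calls fail loudly on a mismatch instead of silently truncating.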