Revert "restore 3.9 compatibility by replacing | with Union[]"

This reverts commit 76bafeb99e.
This commit is contained in:
Lincoln Stein
2023-07-03 10:56:41 -04:00
parent 73a27918c6
commit 2465c7987b
16 changed files with 37 additions and 43 deletions

View File

@ -7,7 +7,7 @@ import secrets
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import Field
from pydantic import BaseModel, Field
import einops
import PIL.Image
@ -17,11 +17,12 @@ import psutil
import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
)
from diffusers.pipelines.controlnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline,
@ -45,7 +46,7 @@ from .diffusion import (
InvokeAIDiffuserComponent,
PostprocessingSettings,
)
from .offloading import FullyLoadedModelGroup, ModelGroup
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
@dataclass
class PipelineIntermediateState:
@ -104,7 +105,7 @@ class AddsMaskGuidance:
_debug: Optional[Callable] = None
def __call__(
self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning
self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning
) -> BaseOutput:
output_class = step_output.__class__ # We'll create a new one with masked data.

View File

@ -4,7 +4,7 @@ import warnings
import weakref
from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping
from typing import Callable, Union
from typing import Callable
import torch
from accelerate.utils import send_to_device
@ -117,7 +117,7 @@ class LazilyLoadedModelGroup(ModelGroup):
"""
_hooks: MutableMapping[torch.nn.Module, RemovableHandle]
_current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]
_current_model_ref: Callable[[], torch.nn.Module | _NoModel]
def __init__(self, execution_device: torch.device):
super().__init__(execution_device)