# Tidy `SilenceWarnings` context manager (#6493)
## Summary

No functional changes, just cleaning some things up as I touch the code. This PR cleans up the `SilenceWarnings` context manager:

- Fix type errors
- Enable `SilenceWarnings` to be used as both a context manager and a decorator
- Remove duplicate implementation
- Check the initial verbosity on `__enter__()` rather than `__init__()`
- Save an indentation level in `DenoiseLatents`

## QA Instructions

I generated an image to confirm that warnings are still muted.

## Merge Plan

- [x] ⚠️ Merge https://github.com/invoke-ai/InvokeAI/pull/6492 first, then change the target branch to `main`.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
This commit is contained in commit `7e9a89f8c6`.
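For background (not part of this PR's diff): `contextlib.ContextDecorator` is the standard-library mixin that lets a single class act as both a context manager and a decorator, which is the mechanism this PR adopts for `SilenceWarnings`. A minimal sketch of the pattern, with hypothetical names (`QuietWarnings`, `generate`):

```python
import warnings
from contextlib import ContextDecorator


class QuietWarnings(ContextDecorator):
    """Illustrative stand-in for SilenceWarnings (hypothetical, simplified)."""

    def __enter__(self) -> None:
        # Snapshot the current warning filters and silence everything on entry.
        self._catcher = warnings.catch_warnings()
        self._catcher.__enter__()
        warnings.simplefilter("ignore")

    def __exit__(self, *args) -> None:
        # Restore the filters that were in effect before __enter__.
        self._catcher.__exit__(*args)


# Usable as a context manager...
with QuietWarnings():
    warnings.warn("muted")


# ...or as a decorator: ContextDecorator wraps each call in __enter__/__exit__.
@QuietWarnings()
def generate() -> None:
    warnings.warn("also muted")


generate()
```

Because the decorator wrapper also runs `__exit__` when the wrapped function raises, decorating a function is interchangeable with wrapping its whole body in a `with` block.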
Imports in the DenoiseLatents invocation module:

```diff
@@ -16,7 +16,9 @@ from pydantic import field_validator
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPVisionModelWithProjection
 
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
+from invokeai.app.invocations.controlnet_image_processors import ControlField
 from invokeai.app.invocations.fields import (
     ConditioningField,
     DenoiseMaskField,
@@ -27,6 +29,7 @@ from invokeai.app.invocations.fields import (
     UIType,
 )
 from invokeai.app.invocations.ip_adapter import IPAdapterField
+from invokeai.app.invocations.model import ModelIdentifierField, UNetField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -36,6 +39,11 @@ from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion.diffusers_pipeline import (
+    ControlNetData,
+    StableDiffusionGeneratorPipeline,
+    T2IAdapterData,
+)
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     IPAdapterConditioningInfo,
@@ -45,20 +53,11 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     TextConditioningData,
     TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.mask import to_standard_float_mask
 from invokeai.backend.util.silence_warnings import SilenceWarnings
-
-from ...backend.stable_diffusion.diffusers_pipeline import (
-    ControlNetData,
-    StableDiffusionGeneratorPipeline,
-    T2IAdapterData,
-)
-from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import TorchDevice
-from .baseinvocation import BaseInvocation, invocation
-from .controlnet_image_processors import ControlField
-from .model import ModelIdentifierField, UNetField
 
 
 def get_scheduler(
     context: InvocationContext,
```
In `DenoiseLatentsInvocation.invoke()` (`@@ -658,155 +657,155 @@ class DenoiseLatentsInvocation(BaseInvocation):`), the `with SilenceWarnings():  # this quenches NSFW nag from diffusers` wrapper is replaced by a `@SilenceWarnings()` decorator on the method, and the body is de-indented by one level; the code itself is otherwise unchanged. New version of the hunk:

```python
        return 1 - mask, masked_latents, self.denoise_mask.gradient

    @torch.no_grad()
    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        seed = None
        noise = None
        if self.noise is not None:
            noise = context.tensors.load(self.noise.latents_name)
            seed = self.noise.seed

        if self.latents is not None:
            latents = context.tensors.load(self.latents.latents_name)
            if seed is None:
                seed = self.latents.seed

            if noise is not None and noise.shape[1:] != latents.shape[1:]:
                raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")

        elif noise is not None:
            latents = torch.zeros_like(noise)
        else:
            raise Exception("'latents' or 'noise' must be provided!")

        if seed is None:
            seed = 0

        mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)

        # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
        # below. Investigate whether this is appropriate.
        t2i_adapter_data = self.run_t2i_adapters(
            context,
            self.t2i_adapter,
            latents.shape,
            do_classifier_free_guidance=True,
        )

        ip_adapters: List[IPAdapterField] = []
        if self.ip_adapter is not None:
            # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
            if isinstance(self.ip_adapter, list):
                ip_adapters = self.ip_adapter
            else:
                ip_adapters = [self.ip_adapter]

        # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
        # a series of image conditioning embeddings. This is being done here rather than in the
        # big model context below in order to use less VRAM on low-VRAM systems.
        # The image prompts are then passed to prep_ip_adapter_data().
        image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)

        # get the unet's config so that we can pass the base to dispatch_progress()
        unet_config = context.models.get_config(self.unet.unet.key)

        def step_callback(state: PipelineIntermediateState) -> None:
            context.util.sd_step_callback(state, unet_config.base)

        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
            for lora in self.unet.loras:
                lora_info = context.models.load(lora.lora)
                assert isinstance(lora_info.model, LoRAModelRaw)
                yield (lora_info.model, lora.weight)
                del lora_info
            return

        unet_info = context.models.load(self.unet.unet)
        assert isinstance(unet_info.model, UNet2DConditionModel)
        with (
            ExitStack() as exit_stack,
            unet_info.model_on_device() as (model_state_dict, unet),
            ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
            set_seamless(unet, self.unet.seamless_axes),  # FIXME
            # Apply the LoRA after unet has been moved to its target device for faster patching.
            ModelPatcher.apply_lora_unet(
                unet,
                loras=_lora_loader(),
                model_state_dict=model_state_dict,
            ),
        ):
            assert isinstance(unet, UNet2DConditionModel)
            latents = latents.to(device=unet.device, dtype=unet.dtype)
            if noise is not None:
                noise = noise.to(device=unet.device, dtype=unet.dtype)
            if mask is not None:
                mask = mask.to(device=unet.device, dtype=unet.dtype)
            if masked_latents is not None:
                masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)

            scheduler = get_scheduler(
                context=context,
                scheduler_info=self.unet.scheduler,
                scheduler_name=self.scheduler,
                seed=seed,
            )

            pipeline = self.create_pipeline(unet, scheduler)

            _, _, latent_height, latent_width = latents.shape
            conditioning_data = self.get_conditioning_data(
                context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
            )

            controlnet_data = self.prep_control_data(
                context=context,
                control_input=self.control,
                latents_shape=latents.shape,
                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                do_classifier_free_guidance=True,
                exit_stack=exit_stack,
            )

            ip_adapter_data = self.prep_ip_adapter_data(
                context=context,
                ip_adapters=ip_adapters,
                image_prompts=image_prompts,
                exit_stack=exit_stack,
                latent_height=latent_height,
                latent_width=latent_width,
                dtype=unet.dtype,
            )

            num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
                scheduler,
                device=unet.device,
                steps=self.steps,
                denoising_start=self.denoising_start,
                denoising_end=self.denoising_end,
                seed=seed,
            )

            result_latents = pipeline.latents_from_embeddings(
                latents=latents,
                timesteps=timesteps,
                init_timestep=init_timestep,
                noise=noise,
                seed=seed,
                mask=mask,
                masked_latents=masked_latents,
                gradient_mask=gradient_mask,
                num_inference_steps=num_inference_steps,
                scheduler_step_kwargs=scheduler_step_kwargs,
                conditioning_data=conditioning_data,
                control_data=controlnet_data,
                ip_adapter_data=ip_adapter_data,
                t2i_adapter_data=t2i_adapter_data,
                callback=step_callback,
            )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        result_latents = result_latents.to("cpu")
        TorchDevice.empty_cache()

        name = context.tensors.save(tensor=result_latents)
        return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
```
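To make the refactor explicit (a sketch, not from the PR text): because `SilenceWarnings` is now a `ContextDecorator`, decorating the method is equivalent to entering the context on call and exiting on return or exception, so the body no longer needs a `with` block. Hypothetical `run_denoise` names below:

```python
from invokeai.backend.util.silence_warnings import SilenceWarnings


# New style (this PR): the entire call runs inside the silencing context.
@SilenceWarnings()
def run_denoise() -> None:
    ...  # body sits one indentation level shallower


# Old style (before this PR): same behavior, one extra nesting level.
def run_denoise_old() -> None:
    with SilenceWarnings():
        ...  # same body
```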
The remaining import of the duplicate implementation is repointed at the canonical module:

```diff
@@ -10,7 +10,7 @@ from picklescan.scanner import scan_file_path
 import invokeai.backend.util.logging as logger
 from invokeai.app.util.misc import uuid_string
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
-from invokeai.backend.util.util import SilenceWarnings
+from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 from .config import (
     AnyModelConfig,
```
In `invokeai.backend.util.silence_warnings`, the class is rebuilt on `ContextDecorator`, the docstring documents both usages, and the verbosity snapshot moves into `__enter__()`:

````diff
@@ -1,29 +1,36 @@
 """Context class to silence transformers and diffusers warnings."""
 
 import warnings
-from typing import Any
+from contextlib import ContextDecorator
 
-from diffusers import logging as diffusers_logging
+from diffusers.utils import logging as diffusers_logging
 from transformers import logging as transformers_logging
 
 
-class SilenceWarnings(object):
-    """Use in context to temporarily turn off warnings from transformers & diffusers modules.
+# Inherit from ContextDecorator to allow using SilenceWarnings as both a context manager and a decorator.
+class SilenceWarnings(ContextDecorator):
+    """A context manager that disables warnings from transformers & diffusers modules while active.
 
-    with SilenceWarnings():
-        # do something
+    As context manager:
+    ```
+    with SilenceWarnings():
+        # do something
+    ```
+
+    As decorator:
+    ```
+    @SilenceWarnings()
+    def some_function():
+        # do something
+    ```
     """
 
-    def __init__(self) -> None:
-        self.transformers_verbosity = transformers_logging.get_verbosity()
-        self.diffusers_verbosity = diffusers_logging.get_verbosity()
-
     def __enter__(self) -> None:
+        self._transformers_verbosity = transformers_logging.get_verbosity()
+        self._diffusers_verbosity = diffusers_logging.get_verbosity()
         transformers_logging.set_verbosity_error()
         diffusers_logging.set_verbosity_error()
         warnings.simplefilter("ignore")
 
-    def __exit__(self, *args: Any) -> None:
-        transformers_logging.set_verbosity(self.transformers_verbosity)
-        diffusers_logging.set_verbosity(self.diffusers_verbosity)
+    def __exit__(self, *args) -> None:
+        transformers_logging.set_verbosity(self._transformers_verbosity)
+        diffusers_logging.set_verbosity(self._diffusers_verbosity)
         warnings.simplefilter("default")
````
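A likely motivation for moving the verbosity capture from `__init__()` to `__enter__()` (my inference; the PR only states the change): when the class is used as a decorator, the instance is constructed once at definition time, so capturing verbosity in the constructor would record whatever level happened to be set at import and restore that possibly stale value after every call. Capturing on entry snapshots the level in effect at each call. A sketch, with hypothetical `generate_image`/`configure_logging` names:

```python
from transformers import logging as transformers_logging

from invokeai.backend.util.silence_warnings import SilenceWarnings


@SilenceWarnings()  # The SilenceWarnings instance is created here, at import time.
def generate_image() -> None:
    ...  # transformers/diffusers warnings are muted for the duration of the call


def configure_logging() -> None:
    # Hypothetical app setup that runs after import but before generate_image().
    transformers_logging.set_verbosity_info()


configure_logging()
generate_image()
# With __enter__-time capture (this PR), the INFO level set above is restored after the call.
# Capturing in __init__ would instead restore the verbosity that was in effect at import time.
```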
In `invokeai.backend.util.util`, the now-unused warning/logging imports are dropped:

```diff
@@ -3,12 +3,9 @@ import io
 import os
 import re
 import unicodedata
-import warnings
 from pathlib import Path
 
-from diffusers import logging as diffusers_logging
 from PIL import Image
-from transformers import logging as transformers_logging
 
 # actual size of a gig
 GIG = 1073741824
```
...and the duplicate `SilenceWarnings` implementation is removed:

```diff
@@ -80,21 +77,3 @@ class Chdir(object):
 
     def __exit__(self, *args):
         os.chdir(self.original)
-
-
-class SilenceWarnings(object):
-    """Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
-
-    def __enter__(self):
-        """Set verbosity to error."""
-        self.transformers_verbosity = transformers_logging.get_verbosity()
-        self.diffusers_verbosity = diffusers_logging.get_verbosity()
-        transformers_logging.set_verbosity_error()
-        diffusers_logging.set_verbosity_error()
-        warnings.simplefilter("ignore")
-
-    def __exit__(self, type, value, traceback):
-        """Restore logger verbosity to state before context was entered."""
-        transformers_logging.set_verbosity(self.transformers_verbosity)
-        diffusers_logging.set_verbosity(self.diffusers_verbosity)
-        warnings.simplefilter("default")
```