quench safety checker warnings from diffusers

This commit is contained in:
Lincoln Stein 2023-02-03 10:14:51 -05:00
parent 9e46badc40
commit 9ae55c91cc
2 changed files with 21 additions and 14 deletions

View File

@ -45,11 +45,11 @@ from diffusers import (
PNDMScheduler, PNDMScheduler,
StableDiffusionPipeline, StableDiffusionPipeline,
UNet2DConditionModel, UNet2DConditionModel,
logging as dlogging,
) )
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
from diffusers.utils import is_safetensors_available from diffusers.utils import is_safetensors_available
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline

View File

@ -8,12 +8,13 @@ import argparse
import os import os
import sys import sys
import traceback import traceback
import warnings
from argparse import Namespace from argparse import Namespace
from pathlib import Path from pathlib import Path
from typing import List, Union from typing import List, Union
import npyscreen import npyscreen
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline, logging as dlogging
from omegaconf import OmegaConf from omegaconf import OmegaConf
from ldm.invoke.globals import ( from ldm.invoke.globals import (
@ -46,6 +47,11 @@ def merge_diffusion_models(
**kwargs - the default DiffusionPipeline.get_config_dict kwargs: **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
""" """
with warnings.catch_warnings():
warnings.simplefilter('ignore')
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
pipe = DiffusionPipeline.from_pretrained( pipe = DiffusionPipeline.from_pretrained(
model_ids_or_paths[0], model_ids_or_paths[0],
cache_dir=kwargs.get("cache_dir", global_cache_dir()), cache_dir=kwargs.get("cache_dir", global_cache_dir()),
@ -58,6 +64,7 @@ def merge_diffusion_models(
force=force, force=force,
**kwargs, **kwargs,
) )
dlogging.set_verbosity(verbosity)
return merged_pipe return merged_pipe