quench safety checker warnings from diffusers

Lincoln Stein 2023-02-03 10:14:51 -05:00
parent 9e46badc40
commit 9ae55c91cc
2 changed files with 21 additions and 14 deletions

@@ -45,11 +45,11 @@ from diffusers import (
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
    logging as dlogging,
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
from diffusers.utils import is_safetensors_available
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
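
The hunk above pulls in diffusers' logging controls (get_verbosity, set_verbosity, set_verbosity_error, plus the "logging as dlogging" alias). Below is a minimal sketch of how these helpers are typically combined with Python's warnings module to mute library chatter around a noisy call; the wrapper name quiet_call is hypothetical and not part of this commit:

import warnings
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error

def quiet_call(fn, *args, **kwargs):
    # Hypothetical helper: drop diffusers log output to ERROR and mute Python
    # warnings while fn runs, then restore the previous verbosity.
    verbosity = get_verbosity()
    set_verbosity_error()
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return fn(*args, **kwargs)
    finally:
        set_verbosity(verbosity)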

@@ -8,12 +8,13 @@ import argparse
import os
import sys
import traceback
import warnings
from argparse import Namespace
from pathlib import Path
from typing import List, Union
import npyscreen
from diffusers import DiffusionPipeline
from diffusers import DiffusionPipeline, logging as dlogging
from omegaconf import OmegaConf
from ldm.invoke.globals import (
@@ -46,18 +47,24 @@ def merge_diffusion_models(
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
    pipe = DiffusionPipeline.from_pretrained(
        model_ids_or_paths[0],
        cache_dir=kwargs.get("cache_dir", global_cache_dir()),
        custom_pipeline="checkpoint_merger",
    )
    merged_pipe = pipe.merge(
        pretrained_model_name_or_path_list=model_ids_or_paths,
        alpha=alpha,
        interp=interp,
        force=force,
        **kwargs,
    )
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        verbosity = dlogging.get_verbosity()
        dlogging.set_verbosity_error()
        pipe = DiffusionPipeline.from_pretrained(
            model_ids_or_paths[0],
            cache_dir=kwargs.get("cache_dir", global_cache_dir()),
            custom_pipeline="checkpoint_merger",
        )
        merged_pipe = pipe.merge(
            pretrained_model_name_or_path_list=model_ids_or_paths,
            alpha=alpha,
            interp=interp,
            force=force,
            **kwargs,
        )
        dlogging.set_verbosity(verbosity)
    return merged_pipe
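
For context, a hedged usage sketch of merge_diffusion_models, assuming its signature mirrors the names used in the body (model_ids_or_paths, alpha, interp, force, **kwargs); the model identifiers, blend weight, and interpolation mode below are illustrative assumptions, not values from the commit:

from ldm.invoke.globals import global_cache_dir

merged = merge_diffusion_models(
    ["runwayml/stable-diffusion-v1-5", "my-finetuned-model"],  # assumed model IDs/paths
    alpha=0.5,           # assumed blend weight
    interp="sigmoid",    # assumed interpolation mode passed through to checkpoint_merger
    force=False,
    cache_dir=global_cache_dir(),
)
merged.save_pretrained("./merged-model")  # save like any DiffusionPipeline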