pass original_config_file to load_single_file()

This commit is contained in:
Lincoln Stein 2024-08-13 21:12:40 -04:00
parent d5c9f4e47f
commit b5ec04f10c
3 changed files with 9 additions and 6 deletions

View File

@ -118,13 +118,16 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
# Some weights of the model checkpoint were not used when initializing CLIPTextModelWithProjection:
# ['text_model.embeddings.position_ids']
original_config_file = self._app_config.legacy_conf_path / config.config_path
with SilenceWarnings():
pipeline = load_class.from_single_file(
config.path,
original_config_file=original_config_file,
torch_dtype=self._torch_dtype,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
load_safety_checker=False,
kwargs={"load_safety_checker": False},
)
if not submodel_type:

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import diffusers
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlNetMixin
from diffusers.loaders import FromOriginalModelMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from diffusers.models.embeddings import (
@ -32,7 +32,7 @@ from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger(__name__)
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
A ControlNet model.

View File

@ -33,11 +33,11 @@ classifiers = [
]
dependencies = [
# Core generation dependencies, pinned for reproducible builds.
"accelerate==0.30.1",
"accelerate==0.31.0",
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==2.0.2",
"controlnet-aux==0.0.7",
"diffusers[torch]==0.27.2",
"diffusers[torch]==0.30.0",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
@ -57,7 +57,7 @@ dependencies = [
# Core application dependencies, pinned for reproducible builds.
"fastapi-events==0.11.1",
"fastapi==0.111.0",
"huggingface-hub==0.23.1",
"huggingface-hub==0.23.5",
"pydantic-settings==2.2.1",
"pydantic==2.7.2",
"python-socketio==5.11.1",