fix(model_cache): don't check model.config in diffusers format

clean-up from recent merge.
This commit is contained in:
Kevin Turner 2022-11-21 16:46:32 -08:00
parent a6a766dfa2
commit f3f6213b97

View File

@ -4,29 +4,28 @@ They are moved between GPU and CPU as necessary. If CPU memory falls
below a preset minimum, the least recently used model will be below a preset minimum, the least recently used model will be
cleared and loaded from disk when next needed. cleared and loaded from disk when next needed.
''' '''
import gc
import hashlib
import io
import os
import sys
import time
import traceback
import warnings import warnings
from pathlib import Path from pathlib import Path
import torch import torch
import os
import io
import time
import gc
import hashlib
import psutil
import sys
import transformers import transformers
import traceback
import textwrap import textwrap
import contextlib import contextlib
from typing import Union from typing import Union
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError from omegaconf.errors import ConfigAttributeError
from picklescan.scanner import scan_file_path
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ldm.util import instantiate_from_config, ask_user
from ldm.invoke.globals import Globals from ldm.invoke.globals import Globals
from picklescan.scanner import scan_file_path from ldm.util import instantiate_from_config, ask_user
DEFAULT_MAX_MODELS=2 DEFAULT_MAX_MODELS=2
@ -240,6 +239,13 @@ class ModelCache(object):
width = mconfig.width width = mconfig.width
height = mconfig.height height = mconfig.height
if not os.path.isabs(config):
config = os.path.join(Globals.root,config)
if not os.path.isabs(weights):
weights = os.path.normpath(os.path.join(Globals.root,weights))
# scan model
self._scan_model(model_name, weights)
c = OmegaConf.load(config) c = OmegaConf.load(config)
with open(weights, 'rb') as f: with open(weights, 'rb') as f:
weight_bytes = f.read() weight_bytes = f.read()