remove redundant prediction_type and attention_upscaling flags

This commit is contained in:
Lincoln Stein 2023-06-23 16:54:52 -04:00
parent 466ec3ab5e
commit 539d1f3bde
3 changed files with 6 additions and 27 deletions

View File

@@ -631,8 +631,8 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
} }
) )
) )
with open(root / 'invokeai.yaml','w') as f: # with open(root / 'invokeai.yaml','w') as f:
f.write('#empty invokeai.yaml initialization file') # f.write('#empty invokeai.yaml initialization file')
# ------------------------------------- # -------------------------------------
def run_console_ui( def run_console_ui(

View File

@@ -3,8 +3,6 @@ Utility (backend) functions used by model_install.py
""" """
import os import os
import shutil import shutil
import sys
import traceback
import warnings import warnings
from dataclasses import dataclass,field from dataclasses import dataclass,field
from pathlib import Path from pathlib import Path
@@ -12,10 +10,9 @@ from tempfile import TemporaryDirectory
from typing import List, Dict, Callable, Union, Set from typing import List, Dict, Callable, Union, Set
import requests import requests
from diffusers import AutoencoderKL, StableDiffusionPipeline from diffusers import StableDiffusionPipeline
from huggingface_hub import hf_hub_url, HfFolder, HfApi from huggingface_hub import hf_hub_url, HfFolder, HfApi
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm from tqdm import tqdm
import invokeai.configs as configs import invokeai.configs as configs
@@ -24,7 +21,6 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
from invokeai.backend.util import download_with_resume from invokeai.backend.util import download_with_resume
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util.logging import InvokeAILogger from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@@ -290,7 +286,7 @@ class ModelInstall(object):
location = self._download_hf_model(repo_id, files, staging) location = self._download_hf_model(repo_id, files, staging)
break break
elif f'learned_embeds.{suffix}' in files: elif f'learned_embeds.{suffix}' in files:
location = self._download_hf_model(repo_id, [f'learned_embeds.suffix'], staging) location = self._download_hf_model(repo_id, ['learned_embeds.suffix'], staging)
break break
if not location: if not location:
logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.') logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.')
@@ -307,7 +303,6 @@ class ModelInstall(object):
self._install_path(dest, info) self._install_path(dest, info)
def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict: def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict:
# convoluted way to retrieve the description from datasets # convoluted way to retrieve the description from datasets
description = f'{info.base_type.value} {info.model_type.value} model' description = f'{info.base_type.value} {info.model_type.value} model'
if key := self.reverse_paths.get(self.current_id): if key := self.reverse_paths.get(self.current_id):
@@ -320,18 +315,7 @@ class ModelInstall(object):
model_format = info.format, model_format = info.format,
) )
if info.model_type == ModelType.Pipeline: if info.model_type == ModelType.Pipeline:
attributes.update( attributes.update(dict(variant = info.variant_type,))
dict(
variant = info.variant_type,
)
)
if info.base_type == BaseModelType.StableDiffusion2:
attributes.update(
dict(
prediction_type = info.prediction_type,
upcast_attention = info.prediction_type == SchedulerPredictionType.VPrediction,
)
)
if info.format=="checkpoint": if info.format=="checkpoint":
try: try:
legacy_conf = LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type] if info.base_type == BaseModelType.StableDiffusion2 \ legacy_conf = LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type] if info.base_type == BaseModelType.StableDiffusion2 \

View File

@@ -131,17 +131,12 @@ class StableDiffusion2Model(DiffusersModel):
model_format: Literal[StableDiffusion2ModelFormat.Diffusers] model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
class CheckpointConfig(ModelConfigBase): class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint] model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
config: Optional[str] = Field(None) config: str
variant: ModelVariantType variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion2 assert base_model == BaseModelType.StableDiffusion2