blackified

This commit is contained in:
Lincoln Stein 2023-07-29 17:30:54 -04:00
parent e82eb0b9fc
commit 348bee8981
3 changed files with 12 additions and 10 deletions

View File

@ -6,8 +6,7 @@ from pydantic import Field
from invokeai.app.invocations.prompt import PromptOutput
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .math import FloatOutput, IntOutput
# Pass-through parameter nodes - used by subgraphs
@ -68,6 +67,7 @@ class ParamStringInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(text=self.text)
class ParamPromptInvocation(BaseInvocation):
"""A prompt input parameter"""
@ -80,4 +80,4 @@ class ParamPromptInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> PromptOutput:
return PromptOutput(prompt=self.prompt)
return PromptOutput(prompt=self.prompt)

View File

@ -1252,7 +1252,7 @@ def download_from_original_stable_diffusion_ckpt(
checkpoint = checkpoint["state_dict"]
logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}")
precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias"
logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}")
precision = precision or checkpoint[precision_probing_key].dtype
@ -1284,9 +1284,9 @@ def download_from_original_stable_diffusion_ckpt(
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
if original_config['model']['params'].get('use_ema') is not None:
extract_ema = original_config['model']['params']['use_ema']
if original_config["model"]["params"].get("use_ema") is not None:
extract_ema = original_config["model"]["params"]["use_ema"]
if (
model_version == BaseModelType.StableDiffusion2
and original_config["model"]["params"].get("parameterization") == "v"
@ -1689,9 +1689,9 @@ def download_controlnet_from_original_ckpt(
checkpoint = checkpoint["state_dict"]
# use original precision
precision_probing_key = 'input_blocks.0.0.bias'
precision_probing_key = "input_blocks.0.0.bias"
ckpt_precision = checkpoint[precision_probing_key].dtype
logger.debug(f'original controlnet precision = {ckpt_precision}')
logger.debug(f"original controlnet precision = {ckpt_precision}")
precision = precision or ckpt_precision
original_config = OmegaConf.load(original_config_file)

View File

@ -19,6 +19,7 @@ from .base import (
from invokeai.app.services.config import InvokeAIAppConfig
import invokeai.backend.util.logging as logger
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
@ -123,6 +124,7 @@ class ControlNetModel(ModelBase):
else:
return model_path
def _convert_controlnet_ckpt_and_cache(
model_path: str,
output_path: str,
@ -137,7 +139,7 @@ def _convert_controlnet_ckpt_and_cache(
app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_path / model_path
output_path = Path(output_path)
logger.info(f"Converting {weights} to diffusers format")
# return cached version if it exists
if output_path.exists():