refiner only has clip2 not clip

This commit is contained in:
Lincoln Stein 2023-07-16 12:36:38 -04:00
parent 0a2964d8c0
commit 6534288b75

View File

@ -21,17 +21,20 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
class SDXLRefinerModelLoaderOutput(SDXLModelLoaderOutput):
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output"""
# fmt: off
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
#fmt: on
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
@ -50,10 +53,6 @@ class SDXLModelLoaderInvocation(BaseInvocation):
},
}
@classmethod
def _output_class(cls):
return SDXLModelLoaderOutput
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
@ -67,7 +66,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
return self._output_class(
return SDXLModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
@ -125,7 +124,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
),
)
class SDXLRefinerModelLoaderInvocation(SDXLModelLoaderInvocation):
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
@ -139,10 +138,61 @@ class SDXLRefinerModelLoaderInvocation(SDXLModelLoaderInvocation):
},
}
@classmethod
def _output_class(cls):
return SDXLRefinerModelLoaderOutput
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
return SDXLRefinerModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Vae,
),
),
)
# Text to image
class SDXLTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""