mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refiner only has clip2 not clip
This commit is contained in:
parent
0a2964d8c0
commit
6534288b75
@ -21,16 +21,19 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)")
|
||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
|
||||
class SDXLRefinerModelLoaderOutput(SDXLModelLoaderOutput):
|
||||
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL refiner model loader output"""
|
||||
# fmt: off
|
||||
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
@ -50,10 +53,6 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _output_class(cls):
|
||||
return SDXLModelLoaderOutput
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
@ -67,7 +66,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
return self._output_class(
|
||||
return SDXLModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
@ -125,7 +124,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
),
|
||||
)
|
||||
|
||||
class SDXLRefinerModelLoaderInvocation(SDXLModelLoaderInvocation):
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
||||
|
||||
@ -139,9 +138,60 @@ class SDXLRefinerModelLoaderInvocation(SDXLModelLoaderInvocation):
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _output_class(cls):
|
||||
return SDXLRefinerModelLoaderOutput
|
||||
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
return SDXLRefinerModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# Text to image
|
||||
class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
|
Loading…
Reference in New Issue
Block a user