Updated to fix Annotated pydantic errors on modelInfo

This commit is contained in:
Kent Keirsey 2023-03-28 10:39:20 -04:00
parent f53b125caa
commit 5860b517a7

View File

@ -15,11 +15,9 @@ class VaeRepo(BaseModel):
path: Optional[str] = Field(description="The path to the VAE") path: Optional[str] = Field(description="The path to the VAE")
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE") subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model") description: Optional[str] = Field(description="A description of the model")
class CkptModelInfo(ModelInfo): class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt' format: Literal['ckpt'] = 'ckpt'
@ -29,7 +27,6 @@ class CkptModelInfo(ModelInfo):
width: Optional[int] = Field(description="The width of the model") width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model") height: Optional[int] = Field(description="The height of the model")
class DiffusersModelInfo(ModelInfo): class DiffusersModelInfo(ModelInfo):
format: Literal['diffusers'] = 'diffusers' format: Literal['diffusers'] = 'diffusers'
@ -37,14 +34,17 @@ class DiffusersModelInfo(ModelInfo):
repo_id: Optional[str] = Field(description="The repo ID to use for this model") repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model") path: Optional[str] = Field(description="The path to the model")
class modelInfo(ModelInfo):
info: Annotated[Union[CkptModelInfo,DiffusersModelInfo], Field(discriminator="format")]
class CreateModelRequest (BaseModel): class CreateModelRequest (BaseModel):
name: str = Field(description="The name of the model") name: str = Field(description="The name of the model")
info: Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")] = Field(description="The model info") info: modelInfo = Field(description="The model details and configuration")
class CreateModelResponse (BaseModel): class CreateModelResponse (BaseModel):
name: str = Field(description="The name of the new model") name: str = Field(description="The name of the new model")
info: Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")] = Field(description="The model info") info: modelInfo = Field(description="The model details and configuration")
status: str = Field(description="The status of the API response")
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]] models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
@ -83,7 +83,7 @@ async def update_model(
model_attributes=model_request.info, model_attributes=model_request.info,
clobber=True, clobber=True,
) )
model_response = CreateModelResponse(status="success") model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
except Exception as e: except Exception as e:
# Handle any exceptions thrown during the execution of the method # Handle any exceptions thrown during the execution of the method