add fp16 support to controlnet models

This commit is contained in:
Lincoln Stein
2023-07-15 10:11:41 -04:00
committed by Kent Keirsey
parent 52948a1bbc
commit e01706f5f5
11 changed files with 342 additions and 347 deletions

View File

@ -1,8 +1,7 @@
import os
import torch
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from typing import Optional
from .base import (
ModelBase,
ModelConfigBase,
@ -14,6 +13,7 @@ from .base import (
calc_model_size_by_data,
classproperty,
InvalidModelException,
ModelNotFoundException,
)
class ControlNetModelFormat(str, Enum):
@ -60,10 +60,20 @@ class ControlNetModel(ModelBase):
if child_type is not None:
raise Exception("There is no child models in controlnet model")
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
)
model = None
for variant in ['fp16',None]:
try:
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
variant=variant,
)
break
except:
pass
if not model:
raise ModelNotFoundException()
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model