mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
this fixes the inconsistent use of self.device, sometimes a str and sometimes an obj
This commit is contained in:
parent
dc30adfbb4
commit
01e05a98de
@ -133,31 +133,31 @@ class T2I:
|
|||||||
embedding_path=None,
|
embedding_path=None,
|
||||||
# just to keep track of this parameter when regenerating prompt
|
# just to keep track of this parameter when regenerating prompt
|
||||||
latent_diffusion_weights=False,
|
latent_diffusion_weights=False,
|
||||||
device='cuda',
|
|
||||||
):
|
):
|
||||||
self.iterations = iterations
|
self.iterations = iterations
|
||||||
self.width = width
|
self.width = width
|
||||||
self.height = height
|
self.height = height
|
||||||
self.steps = steps
|
self.steps = steps
|
||||||
self.cfg_scale = cfg_scale
|
self.cfg_scale = cfg_scale
|
||||||
self.weights = weights
|
self.weights = weights
|
||||||
self.config = config
|
self.config = config
|
||||||
self.sampler_name = sampler_name
|
self.sampler_name = sampler_name
|
||||||
self.latent_channels = latent_channels
|
self.latent_channels = latent_channels
|
||||||
self.downsampling_factor = downsampling_factor
|
self.downsampling_factor = downsampling_factor
|
||||||
self.grid = grid
|
self.grid = grid
|
||||||
self.ddim_eta = ddim_eta
|
self.ddim_eta = ddim_eta
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.full_precision = full_precision
|
self.full_precision = full_precision
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
self.embedding_path = embedding_path
|
self.embedding_path = embedding_path
|
||||||
self.model = None # empty for now
|
self.model = None # empty for now
|
||||||
self.sampler = None
|
self.sampler = None
|
||||||
|
self.device = None
|
||||||
self.latent_diffusion_weights = latent_diffusion_weights
|
self.latent_diffusion_weights = latent_diffusion_weights
|
||||||
self.device = device
|
|
||||||
|
|
||||||
# for VRAM usage statistics
|
# for VRAM usage statistics
|
||||||
self.session_peakmem = torch.cuda.max_memory_allocated() if self.device == 'cuda' else None
|
device_type = choose_torch_device()
|
||||||
|
self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
|
||||||
|
|
||||||
if seed is None:
|
if seed is None:
|
||||||
self.seed = self._new_seed()
|
self.seed = self._new_seed()
|
||||||
@ -251,14 +251,15 @@ class T2I:
|
|||||||
to create the requested output directory, select a unique informative name for each image, and
|
to create the requested output directory, select a unique informative name for each image, and
|
||||||
write the prompt into the PNG metadata.
|
write the prompt into the PNG metadata.
|
||||||
"""
|
"""
|
||||||
steps = steps or self.steps
|
# TODO: convert this into a getattr() loop
|
||||||
seed = seed or self.seed
|
steps = steps or self.steps
|
||||||
width = width or self.width
|
seed = seed or self.seed
|
||||||
height = height or self.height
|
width = width or self.width
|
||||||
cfg_scale = cfg_scale or self.cfg_scale
|
height = height or self.height
|
||||||
ddim_eta = ddim_eta or self.ddim_eta
|
cfg_scale = cfg_scale or self.cfg_scale
|
||||||
iterations = iterations or self.iterations
|
ddim_eta = ddim_eta or self.ddim_eta
|
||||||
strength = strength or self.strength
|
iterations = iterations or self.iterations
|
||||||
|
strength = strength or self.strength
|
||||||
self.log_tokenization = log_tokenization
|
self.log_tokenization = log_tokenization
|
||||||
|
|
||||||
model = (
|
model = (
|
||||||
@ -279,7 +280,7 @@ class T2I:
|
|||||||
self._set_sampler()
|
self._set_sampler()
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
torch.cuda.reset_peak_memory_stats() if self.device == 'cuda' else None
|
torch.cuda.reset_peak_memory_stats() if self.device.type == 'cuda' else None
|
||||||
results = list()
|
results = list()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -60,7 +60,6 @@ def main():
|
|||||||
# this is solely for recreating the prompt
|
# this is solely for recreating the prompt
|
||||||
latent_diffusion_weights=opt.laion400m,
|
latent_diffusion_weights=opt.laion400m,
|
||||||
embedding_path=opt.embedding_path,
|
embedding_path=opt.embedding_path,
|
||||||
device=opt.device,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure the output directory exists
|
# make sure the output directory exists
|
||||||
@ -376,13 +375,6 @@ def create_argv_parser():
|
|||||||
type=str,
|
type=str,
|
||||||
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
|
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
'--device',
|
|
||||||
'-d',
|
|
||||||
type=str,
|
|
||||||
default='cuda',
|
|
||||||
help='Device to run Stable Diffusion on. Defaults to cuda `torch.cuda.current_device()` if avalible',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--prompt_as_dir',
|
'--prompt_as_dir',
|
||||||
'-p',
|
'-p',
|
||||||
|
Loading…
Reference in New Issue
Block a user