This commit is contained in:
SetoKaiba
2024-06-27 22:53:53 +08:00
committed by GitHub
parent 8dd7cfab93
commit 836bfec1fb
2 changed files with 13 additions and 9 deletions

View File

@@ -301,7 +301,7 @@ class TTS:
if self.configs.is_half and str(self.configs.device)!="cpu":
self.t2s_model = self.t2s_model.half()
def enable_half_precision(self, enable: bool = True):
def enable_half_precision(self, enable: bool = True, save: bool = True):
'''
To enable half precision for the TTS model.
Args:
@@ -314,7 +314,8 @@ class TTS:
self.configs.is_half = enable
self.precision = torch.float16 if enable else torch.float32
self.configs.save_configs()
if save:
self.configs.save_configs()
if enable:
if self.t2s_model is not None:
self.t2s_model =self.t2s_model.half()
@@ -334,14 +335,15 @@ class TTS:
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.float()
def set_device(self, device: torch.device):
def set_device(self, device: torch.device, save: bool = True):
'''
To set the device for all models.
Args:
device: torch.device, the device to use for all models.
'''
self.configs.device = device
self.configs.save_configs()
if save:
self.configs.save_configs()
if self.t2s_model is not None:
self.t2s_model = self.t2s_model.to(device)
if self.vits_model is not None: