增加健壮性,防止在cpu推理时设置半精度报错
This commit is contained in:
@@ -228,7 +228,7 @@ class TTS:
|
|||||||
self.cnhuhbert_model = CNHubert(base_path)
|
self.cnhuhbert_model = CNHubert(base_path)
|
||||||
self.cnhuhbert_model=self.cnhuhbert_model.eval()
|
self.cnhuhbert_model=self.cnhuhbert_model.eval()
|
||||||
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
|
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
|
||||||
if self.configs.is_half:
|
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||||
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
||||||
|
|
||||||
|
|
||||||
@@ -239,7 +239,7 @@ class TTS:
|
|||||||
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
|
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
|
||||||
self.bert_model=self.bert_model.eval()
|
self.bert_model=self.bert_model.eval()
|
||||||
self.bert_model = self.bert_model.to(self.configs.device)
|
self.bert_model = self.bert_model.to(self.configs.device)
|
||||||
if self.configs.is_half:
|
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||||
self.bert_model = self.bert_model.half()
|
self.bert_model = self.bert_model.half()
|
||||||
|
|
||||||
|
|
||||||
@@ -272,7 +272,7 @@ class TTS:
|
|||||||
vits_model = vits_model.eval()
|
vits_model = vits_model.eval()
|
||||||
vits_model.load_state_dict(dict_s2["weight"], strict=False)
|
vits_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||||
self.vits_model = vits_model
|
self.vits_model = vits_model
|
||||||
if self.configs.is_half:
|
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||||
self.vits_model = self.vits_model.half()
|
self.vits_model = self.vits_model.half()
|
||||||
|
|
||||||
|
|
||||||
@@ -290,7 +290,7 @@ class TTS:
|
|||||||
t2s_model = t2s_model.to(self.configs.device)
|
t2s_model = t2s_model.to(self.configs.device)
|
||||||
t2s_model = t2s_model.eval()
|
t2s_model = t2s_model.eval()
|
||||||
self.t2s_model = t2s_model
|
self.t2s_model = t2s_model
|
||||||
if self.configs.is_half:
|
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||||
self.t2s_model = self.t2s_model.half()
|
self.t2s_model = self.t2s_model.half()
|
||||||
|
|
||||||
def enable_half_precision(self, enable: bool = True):
|
def enable_half_precision(self, enable: bool = True):
|
||||||
@@ -300,7 +300,7 @@ class TTS:
|
|||||||
enable: bool, whether to enable half precision.
|
enable: bool, whether to enable half precision.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
if self.configs.device == "cpu" and enable:
|
if str(self.configs.device) == "cpu" and enable:
|
||||||
print("Half precision is not supported on CPU.")
|
print("Half precision is not supported on CPU.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user