From 345f3203f84d6017151f1075bed0e917ac784130 Mon Sep 17 00:00:00 2001 From: chasonjiang <1440499136@qq.com> Date: Tue, 12 Mar 2024 16:08:50 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86=E7=83=AD=E5=88=87?= =?UTF-8?q?=E6=8D=A2=E6=A8=A1=E5=9E=8B=E6=97=B6=EF=BC=8C=E7=B2=BE=E5=BA=A6?= =?UTF-8?q?=E4=B8=8D=E5=8C=B9=E9=85=8D=E5=AF=BC=E8=87=B4=E7=9A=84=E9=94=99?= =?UTF-8?q?=E8=AF=AF=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/TTS_infer_pack/TTS.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 694d4a7..61ba7be 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -206,7 +206,7 @@ class TTS: self.init_vits_weights(self.configs.vits_weights_path) self.init_bert_weights(self.configs.bert_base_path) self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path) - self.enable_half_precision(self.configs.is_half) + # self.enable_half_precision(self.configs.is_half) @@ -215,6 +215,8 @@ class TTS: self.cnhuhbert_model = CNHubert(base_path) self.cnhuhbert_model=self.cnhuhbert_model.eval() self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device) + if self.configs.is_half: + self.cnhuhbert_model = self.cnhuhbert_model.half() @@ -224,6 +226,8 @@ class TTS: self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path) self.bert_model=self.bert_model.eval() self.bert_model = self.bert_model.to(self.configs.device) + if self.configs.is_half: + self.bert_model = self.bert_model.half() @@ -255,6 +259,8 @@ class TTS: vits_model = vits_model.eval() vits_model.load_state_dict(dict_s2["weight"], strict=False) self.vits_model = vits_model + if self.configs.is_half: + self.vits_model = self.vits_model.half() def init_t2s_weights(self, weights_path: str): @@ -271,6 +277,8 @@ class TTS: t2s_model = t2s_model.to(self.configs.device) t2s_model = t2s_model.eval() self.t2s_model = t2s_model + if self.configs.is_half: + self.t2s_model = self.t2s_model.half() def enable_half_precision(self, enable: bool = True): '''