优化代码
This commit is contained in:
@@ -820,7 +820,7 @@ class TTS:
|
||||
|
||||
def empty_cache(self):
|
||||
try:
|
||||
if str(self.configs.device) == "cuda":
|
||||
if "cuda" in str(self.configs.device):
|
||||
torch.cuda.empty_cache()
|
||||
elif str(self.configs.device) == "mps":
|
||||
torch.mps.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user