diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index bcab7e8..e1ed8bb 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -820,7 +820,7 @@ class TTS: def empty_cache(self): try: - if str(self.configs.device) == "cuda": + if "cuda" in str(self.configs.device): torch.cuda.empty_cache() elif str(self.configs.device) == "mps": torch.mps.empty_cache()