12
api_v3.py
12
api_v3.py
@@ -320,7 +320,7 @@ async def tts_handle(req: dict):
|
||||
try:
|
||||
tts_instance = get_tts_instance(tts_config)
|
||||
|
||||
move_to_gpu(tts_instance, tts_config)
|
||||
move_to_original(tts_instance, tts_config)
|
||||
|
||||
tts_generator = tts_instance.run(req)
|
||||
|
||||
@@ -347,13 +347,15 @@ async def tts_handle(req: dict):
|
||||
|
||||
def move_to_cpu(tts):
|
||||
cpu_device = torch.device('cpu')
|
||||
tts.set_device(cpu_device)
|
||||
tts.set_device(cpu_device, False)
|
||||
tts.enable_half_precision(False, False)
|
||||
print("Moved TTS models to CPU to save GPU memory.")
|
||||
|
||||
|
||||
def move_to_gpu(tts: TTS, tts_config: TTS_Config):
|
||||
tts.set_device(tts_config.device)
|
||||
print("Moved TTS models back to GPU for performance.")
|
||||
def move_to_original(tts: TTS, tts_config: TTS_Config):
|
||||
tts.set_device(tts_config.device, False)
|
||||
tts.enable_half_precision(tts_config.is_half, False)
|
||||
print("Moved TTS models back to original device for performance.")
|
||||
|
||||
|
||||
@APP.get("/control")
|
||||
|
||||
Reference in New Issue
Block a user