فهرست منبع

优化课程问答中指定答案的语音合成兼容豆包tts

liyanbo 2 هفته پیش
والد
کامیت
754d2bc409

+ 4 - 4
byzs-module-ai/src/main/java/cn/iocoder/byzs/module/ai/service/tts/DouBaoTtsService.java

@@ -32,7 +32,7 @@ public class DouBaoTtsService {
      */
     public byte[] convertTextToSpeech(AiTtsDO aiTtsDO, String content, String command) throws IOException {
         // 构建请求
-        Request request = buildTtsRequest(aiTtsDO, content, command, OutputFormatEnum.MP3);
+        Request request = buildTtsRequest(aiTtsDO, content, command, OutputFormatEnum.MP3.getName());
         
         // 发送请求并处理响应
         try (Response response = new OkHttpClient().newCall(request).execute()) {
@@ -95,7 +95,7 @@ public class DouBaoTtsService {
      */
     public void streamTextToSpeech(AiTtsDO aiTtsDO, String content, String command, Consumer<byte[]> audioDataCallback) throws IOException {
         // 构建请求
-        Request request = buildTtsRequest(aiTtsDO, content, command, OutputFormatEnum.PCM);
+        Request request = buildTtsRequest(aiTtsDO, content, command, OutputFormatEnum.PCM.getName());
         
         // 发送请求并处理响应
         try (Response response = new OkHttpClient().newCall(request).execute()) {
@@ -145,7 +145,7 @@ public class DouBaoTtsService {
     /**
      * 构建TTS请求
      */
-    private Request buildTtsRequest(AiTtsDO aiTtsDO, String content, String command, OutputFormatEnum format) throws IOException {
+    private Request buildTtsRequest(AiTtsDO aiTtsDO, String content, String command, String format) throws IOException {
         // 获取配置
         ByzsAiProperties.DouBaoProperties doubaoProperties = byzsAiProperties.getDoubao();
         if (doubaoProperties == null) {
@@ -174,7 +174,7 @@ public class DouBaoTtsService {
         
         Map<String, Object> audioParams = new HashMap<>();
         audioParams.put("format", format); // 输出音频格式
-        audioParams.put("sample_rate", SampleRateEnum.SAMPLE_RATE_16K); // 推荐采样率
+        audioParams.put("sample_rate", SampleRateEnum.SAMPLE_RATE_16K.value); // 推荐采样率
         if (aiTtsDO.getEmotion() != null && !aiTtsDO.getEmotion().isEmpty()) {
             audioParams.put("emotion", aiTtsDO.getEmotion());
         }

+ 372 - 14
byzs-web/src/main/java/cn/iocoder/byzs/module/web/service/ai/WebAiServiceImpl.java

@@ -13,7 +13,9 @@ import cn.iocoder.byzs.module.ai.dal.mysql.tts.AiTtsMapper;
 import cn.iocoder.byzs.module.ai.enums.ErrorCodeConstants;
 import cn.iocoder.byzs.module.ai.service.chat.AiChatConversationService;
 import cn.iocoder.byzs.module.ai.service.model.AiChatRoleService;
+import cn.iocoder.byzs.module.ai.service.tts.DouBaoTtsService;
 import cn.iocoder.byzs.module.ai.util.tts.StreamingAliyunTtsService;
+import cn.iocoder.byzs.module.ai.util.tts.StreamingDouBaoTtsService;
 import cn.iocoder.byzs.module.ai.util.tts.WavHeader;
 import com.alibaba.nls.client.protocol.SampleRateEnum;
 import jakarta.annotation.Resource;
@@ -32,6 +34,7 @@ import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
@@ -62,6 +65,18 @@ public class WebAiServiceImpl {
     @Resource
     private ObjectProvider<StreamingAliyunTtsService> streamTtsServiceProvider;
 
+    @Resource
+    private StreamingDouBaoTtsService streamingDouBaoTtsService;
+
+    // 豆包TTS的sink引用,用于发送音频数据
+    private AtomicReference<FluxSink<CommonResult<AiChatMessageSendRespVO>>> douBaoSinkRef;
+    
+    // 标记是否是首次发送豆包TTS音频数据
+    private final AtomicBoolean isFirstDouBaoAudio = new AtomicBoolean(true);
+    
+    // 豆包TTS任务计数器
+    private final AtomicInteger douBaoTtsTaskCount = new AtomicInteger(0);
+
     /**
      * 发送指定回答的SSE流式响应
      * 确保TEXT类型文本数据优先且可靠地发送到前端,同时提供AUDIO音频流
@@ -81,17 +96,37 @@ public class WebAiServiceImpl {
         String contentAnswer = sendReqVO.getContentAnswer();
         log.info("开始处理文本内容: {}", contentAnswer);
 
-        // 4. 创建TTS服务实例
-        StreamingAliyunTtsService streamingAliyunTtsService = streamTtsServiceProvider.getObject();
+        // 4. 检查是否为豆包TTS
+        boolean isDouBaoTts = aiTtsDO != null && "DouBao".equals(aiTtsDO.getPlatform());
+
+        // 5. 创建TTS服务实例
+        StreamingAliyunTtsService streamingAliyunTtsService = null;
+        if (!isDouBaoTts) {
+            streamingAliyunTtsService = streamTtsServiceProvider.getObject();
+        }
 
         try {
-            // 5. 初始化TTS服务
-            streamingAliyunTtsService.startTts(aiTtsDO);
-            return createSseFlux(sendReqVO, userId, conversation, contentAnswer, streamingAliyunTtsService);
+            // 6. 初始化TTS服务
+            if (isDouBaoTts) {
+                // 初始化豆包TTS服务
+                streamingDouBaoTtsService.startTts(aiTtsDO);
+                // 重置豆包TTS的sink引用
+                this.douBaoSinkRef = new AtomicReference<>();
+                // 重置豆包TTS的首次音频标记
+                isFirstDouBaoAudio.set(true);
+            } else {
+                // 初始化阿里云TTS服务
+                streamingAliyunTtsService.startTts(aiTtsDO);
+            }
+            return createSseFlux(sendReqVO, userId, conversation, contentAnswer, streamingAliyunTtsService, isDouBaoTts, aiTtsDO);
         } catch (Exception e) {
             log.error("发送指定回答失败", e);
             AtomicBoolean tempTtsStopped = new AtomicBoolean(false);
-            cleanupTtsResources(streamingAliyunTtsService, tempTtsStopped);
+            if (isDouBaoTts) {
+                cleanupDouBaoTtsResources(tempTtsStopped);
+            } else {
+                cleanupTtsResources(streamingAliyunTtsService, tempTtsStopped);
+            }
             // 即使发生异常,也要返回文本数据,确保前端至少能收到文本
             return Flux.just(createFallbackTextResponse(sendReqVO, userId, conversation, contentAnswer));
         }
@@ -117,7 +152,7 @@ public class WebAiServiceImpl {
      */
     private Flux<CommonResult<AiChatMessageSendRespVO>> createSseFlux(
             AiChatMessageSendReqVO sendReqVO, Long userId, AiChatConversationDO conversation,
-            String contentAnswer, StreamingAliyunTtsService streamingAliyunTtsService) {
+            String contentAnswer, StreamingAliyunTtsService streamingAliyunTtsService, boolean isDouBaoTts, AiTtsDO aiTtsDO) {
         return Flux.<CommonResult<AiChatMessageSendRespVO>>create(sink -> {
                     // 初始化句子处理相关组件
                     ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor(r -> {
@@ -134,19 +169,54 @@ public class WebAiServiceImpl {
                         // 发送文本数据(带type)
                         sendTextData(sink, sendReqVO, userId, conversation, contentAnswer);
 
-                        // 创建音频流并订阅
-                        createAndSubscribeToAudioStream(sink, streamingAliyunTtsService, scheduler, ttsTask, ttsStopped);
+                        // 处理音频流
+                        if (isDouBaoTts) {
+                            // 为豆包TTS设置sink
+                            this.douBaoSinkRef.set(sink);
+                        } else {
+                            // 为阿里云TTS设置音频数据回调
+                            AtomicBoolean isFirstChunk = new AtomicBoolean(true); // 首包标志位
+                            streamingAliyunTtsService.setAudioDataCallback(audioBytes -> {
+                                try {
+                                    byte[] processedAudio = processAudioData(audioBytes, isFirstChunk);
+                                    String base64Audio = java.util.Base64.getEncoder().encodeToString(processedAudio);
+
+                                    AiChatMessageSendRespVO audioResp = new AiChatMessageSendRespVO();
+                                    audioResp.setEventType("AUDIO");
+                                    audioResp.setAudioData(base64Audio);
+                                    sink.next(success(audioResp));
+                                } catch (Exception e) {
+                                    log.error("[TTS处理异常] 音频编码失败", e);
+                                    sink.error(new RuntimeException("TTS音频处理失败: " + e.getMessage(), e));
+                                }
+                            });
+
+                            // 设置完成回调
+                            streamingAliyunTtsService.setOnCompleteCallback(() -> {
+                                log.info("TTS转换完成,准备终止SSE流");
+                                // 不在这里调用sink.complete(),因为文本处理可能还在进行
+                            });
+                        }
 
                         // 开始处理文本分段并发送到TTS
                         Pattern sentencePattern = Pattern.compile("[。!?;\n\r]");
-                        processTextSegments(streamingAliyunTtsService, contentTTSBuffer, sentencePattern,
-                                scheduler, ttsTask, ttsStopped, allTextProcessed, sink);
+                        if (isDouBaoTts) {
+                            processTextSegmentsForDouBao(aiTtsDO, contentTTSBuffer, sentencePattern,
+                                    scheduler, ttsTask, ttsStopped, allTextProcessed, sink);
+                        } else {
+                            processTextSegments(streamingAliyunTtsService, contentTTSBuffer, sentencePattern,
+                                    scheduler, ttsTask, ttsStopped, allTextProcessed, sink);
+                        }
 
                         // 添加超时检测(60秒)
                         ScheduledFuture<?> timeoutTask = scheduler.schedule(() -> {
                             if (!ttsStopped.get()) {
                                 log.warn("TTS处理超时,强制终止SSE流");
-                                cleanupResources(streamingAliyunTtsService, scheduler, ttsTask, ttsStopped, sink);
+                                if (isDouBaoTts) {
+                                    cleanupDouBaoResources(scheduler, ttsTask, ttsStopped, sink);
+                                } else {
+                                    cleanupResources(streamingAliyunTtsService, scheduler, ttsTask, ttsStopped, sink);
+                                }
                             }
                         }, 60, TimeUnit.SECONDS);
 
@@ -156,12 +226,20 @@ public class WebAiServiceImpl {
                             if (timeoutTask != null) {
                                 timeoutTask.cancel(false);
                             }
-                            cleanupResources(streamingAliyunTtsService, scheduler, ttsTask, ttsStopped, sink);
+                            if (isDouBaoTts) {
+                                cleanupDouBaoResources(scheduler, ttsTask, ttsStopped, sink);
+                            } else {
+                                cleanupResources(streamingAliyunTtsService, scheduler, ttsTask, ttsStopped, sink);
+                            }
                         });
 
                     } catch (Exception e) {
                         log.error("创建SSE流异常", e);
-                        cleanupResources(streamingAliyunTtsService, scheduler, ttsTask, ttsStopped, sink);
+                        if (isDouBaoTts) {
+                            cleanupDouBaoResources(scheduler, ttsTask, ttsStopped, sink);
+                        } else {
+                            cleanupResources(streamingAliyunTtsService, scheduler, ttsTask, ttsStopped, sink);
+                        }
                         sink.error(e);
                     }
                 }).subscribeOn(Schedulers.boundedElastic())
@@ -438,6 +516,286 @@ public class WebAiServiceImpl {
         log.info("TTS合成长文本: {}", sentence);
     }
 
+    /**
+     * 创建并订阅豆包TTS音频流
+     */
+    private void createAndSubscribeToDouBaoAudioStream(FluxSink<CommonResult<AiChatMessageSendRespVO>> mainSink,
+                                                      ScheduledExecutorService scheduler,
+                                                      AtomicReference<ScheduledFuture<?>> ttsTask,
+                                                      AtomicBoolean ttsStopped, AiTtsDO aiTtsDO) {
+        Flux.<CommonResult<AiChatMessageSendRespVO>>create(audioSink -> {
+            // 保存豆包TTS的sink引用
+            this.douBaoSinkRef.set(audioSink);
+        }).subscribe(
+                chunk -> {
+                    if (!mainSink.isCancelled()) {
+                        mainSink.next(chunk);
+                    }
+                },
+                error -> {
+                    log.error("豆包TTS音频流处理异常", error);
+                    try {
+                        if (!mainSink.isCancelled()) {
+                            mainSink.error(error);
+                        }
+                    } catch (Exception e) {
+                        log.error("主SSE流设置错误异常", e);
+                    }
+                    cleanupDouBaoResources(scheduler, ttsTask, ttsStopped, mainSink);
+                },
+                () -> {
+                    log.info("豆包TTS音频流处理完成,准备终止主SSE流");
+                    // 音频流完成时,如果主SSE流还未终止,主动终止
+                    try {
+                        if (!mainSink.isCancelled()) {
+                            log.info("豆包TTS音频流完成后主动终止主SSE流");
+                            mainSink.complete();
+                            log.info("主SSE流已成功终止");
+                        }
+                    } catch (Exception e) {
+                        log.error("终止主SSE流异常", e);
+                    }
+                }
+        );
+    }
+
+    /**
+     * 处理豆包TTS的文本分段
+     */
+    private void processTextSegmentsForDouBao(AiTtsDO aiTtsDO, StringBuilder buffer,
+                                             Pattern sentencePattern, ScheduledExecutorService scheduler,
+                                             AtomicReference<ScheduledFuture<?>> ttsTask, AtomicBoolean ttsStopped,
+                                             AtomicBoolean allTextProcessed, FluxSink<CommonResult<AiChatMessageSendRespVO>> sink) {
+        if (buffer.isEmpty()) {
+            log.info("文本为空,无需处理");
+            handleDouBaoTextComplete(aiTtsDO, scheduler, ttsStopped, allTextProcessed, sink);
+            return;
+        }
+
+        // 立即处理文本
+        Matcher matcher = sentencePattern.matcher(buffer);
+        if (matcher.find()) {
+            processDouBaoCompleteSentence(aiTtsDO, buffer, matcher);
+            // 继续调度处理剩余文本
+            scheduleDouBaoNextProcessing(aiTtsDO, buffer, sentencePattern, scheduler,
+                    ttsTask, ttsStopped, allTextProcessed, sink);
+        } else if (buffer.length() > 50) { // 最长50字未结束也处理
+            processDouBaoCompleteSentence(aiTtsDO, buffer, buffer.length());
+            // 继续调度处理剩余文本
+            scheduleDouBaoNextProcessing(aiTtsDO, buffer, sentencePattern, scheduler,
+                    ttsTask, ttsStopped, allTextProcessed, sink);
+        } else {
+            // 文本较短且未结束,直接处理全部
+            log.info("豆包TTS合成短文本: {}", buffer.toString());
+            processDouBaoTts(aiTtsDO, buffer.toString());
+            buffer.setLength(0);
+            handleDouBaoTextComplete(aiTtsDO, scheduler, ttsStopped, allTextProcessed, sink);
+        }
+    }
+
+    /**
+     * 调度豆包TTS下一次文本处理
+     */
+    private void scheduleDouBaoNextProcessing(AiTtsDO aiTtsDO, StringBuilder buffer,
+                                             Pattern sentencePattern, ScheduledExecutorService scheduler,
+                                             AtomicReference<ScheduledFuture<?>> ttsTask, AtomicBoolean ttsStopped,
+                                             AtomicBoolean allTextProcessed, FluxSink<CommonResult<AiChatMessageSendRespVO>> sink) {
+        if (!buffer.isEmpty()) {
+            // 延迟200ms执行,合并短时间内处理的文本片段
+            if (ttsTask.get() != null) {
+                ttsTask.get().cancel(false); // 取消之前的延迟任务
+            }
+            ttsTask.set(scheduler.schedule(() -> {
+                processTextSegmentsForDouBao(aiTtsDO, buffer, sentencePattern, scheduler,
+                        ttsTask, ttsStopped, allTextProcessed, sink);
+            }, 200, TimeUnit.MILLISECONDS));
+        } else {
+            // 所有文本处理完毕
+            handleDouBaoTextComplete(aiTtsDO, scheduler, ttsStopped, allTextProcessed, sink);
+        }
+    }
+
+    /**
+     * 处理豆包TTS文本完成后的逻辑
+     */
+    private void handleDouBaoTextComplete(AiTtsDO aiTtsDO, ScheduledExecutorService scheduler,
+                                         AtomicBoolean ttsStopped, AtomicBoolean allTextProcessed,
+                                         FluxSink<CommonResult<AiChatMessageSendRespVO>> sink) {
+        allTextProcessed.set(true);
+        log.info("所有文本处理完毕,准备通知TTS服务文本已发送完毕");
+
+        // 添加额外的超时检测,作为最后的保障
+        scheduler.schedule(() -> {
+            if (allTextProcessed.get() && !ttsStopped.get()) {
+                log.info("所有文本已发送到TTS服务,但TTS未完成回调,主动终止TTS服务和SSE流");
+                try {
+                    // 设置标志位
+                    ttsStopped.set(true);
+                    // 尝试停止TTS服务
+                    try {
+                        streamingDouBaoTtsService.stopTts();
+                        log.info("超时后主动停止豆包TTS服务完成");
+                    } catch (Exception e) {
+                        log.error("停止豆包TTS服务异常: {}", e.getMessage());
+                        // 即使停止失败,也要确保SSE流终止
+                    }
+
+                    // 确保SSE流终止,无论TTS服务是否成功停止
+                    if (sink != null && !sink.isCancelled()) {
+                        log.info("在handleDouBaoTextComplete超时检测中主动终止主SSE流");
+                        sink.complete();
+                        log.info("主SSE流已成功终止");
+                    }
+                } catch (Exception e) {
+                    log.error("超时检测处理异常", e);
+                }
+            }
+        }, 60, TimeUnit.SECONDS); // 等待60秒后检查
+    }
+
+    /**
+     * 处理豆包TTS完整句子
+     */
+    private void processDouBaoCompleteSentence(AiTtsDO aiTtsDO, StringBuilder buffer, Matcher matcher) {
+        String sentence = buffer.substring(0, matcher.end());
+        processDouBaoTts(aiTtsDO, sentence);
+        buffer.delete(0, matcher.end());
+        log.info("豆包TTS合成完整句: {}", sentence);
+    }
+
+    /**
+     * 处理豆包TTS指定长度文本
+     */
+    private void processDouBaoCompleteSentence(AiTtsDO aiTtsDO, StringBuilder buffer, int length) {
+        String sentence = buffer.substring(0, length);
+        processDouBaoTts(aiTtsDO, sentence);
+        buffer.delete(0, length);
+        log.info("豆包TTS合成长文本: {}", sentence);
+    }
+
+    /**
+     * 处理豆包TTS合成
+     */
+    private void processDouBaoTts(AiTtsDO aiTtsDO, String text) {
+        if (text == null || text.trim().isEmpty()) {
+            return;
+        }
+
+        // 增加任务计数
+        douBaoTtsTaskCount.incrementAndGet();
+
+        // 在单独的线程中处理豆包TTS,避免阻塞主线程
+        new Thread(() -> {
+            try {
+                // 使用豆包TTS流式服务进行处理
+                streamingDouBaoTtsService.sendText(aiTtsDO, text, audioBytes -> {
+                    // 处理音频数据
+                    if (audioBytes != null && audioBytes.length > 0) {
+                        // 豆包TTS现在返回的是PCM格式,需要添加WAV头
+                        byte[] processedAudio;
+                        if (isFirstDouBaoAudio.compareAndSet(true, false)) {
+                            // 首次音频数据:添加WAV头以符合前端期望
+                            processedAudio = WavHeader.addWavHeader(audioBytes, SampleRateEnum.SAMPLE_RATE_16K.value, 16, 1);
+                            log.info("豆包TTS首次音频合成成功,添加WAV头,原始长度: {} bytes,处理后长度: {} bytes", 
+                                    audioBytes.length, processedAudio.length);
+                        } else {
+                            // 后续音频数据:直接使用原始PCM数据
+                            processedAudio = audioBytes;
+                            log.info("豆包TTS后续音频合成成功,长度: {} bytes", processedAudio.length);
+                        }
+
+                        String base64Audio = java.util.Base64.getEncoder().encodeToString(processedAudio);
+
+                        // 创建音频响应对象
+                        AiChatMessageSendRespVO audioResp = new AiChatMessageSendRespVO();
+                        audioResp.setEventType("AUDIO");
+                        audioResp.setAudioData(base64Audio);
+
+                        log.info("豆包TTS合成成功");
+
+                        // 将音频数据发送到前端
+                        if (this.douBaoSinkRef != null && this.douBaoSinkRef.get() != null) {
+                            try {
+                                this.douBaoSinkRef.get().next(success(audioResp));
+                            } catch (Exception e) {
+                                log.error("发送豆包TTS音频数据失败", e);
+                            }
+                        }
+                    }
+                });
+            } catch (Exception e) {
+                log.error("豆包TTS合成失败", e);
+            } finally {
+                // 减少任务计数,当所有任务完成时关闭音频流
+                if (douBaoTtsTaskCount.decrementAndGet() == 0) {
+                    log.info("所有豆包TTS任务已完成,关闭音频流");
+                    if (this.douBaoSinkRef != null && this.douBaoSinkRef.get() != null) {
+                        try {
+                            this.douBaoSinkRef.get().complete();
+                            // 停止豆包TTS服务
+                            streamingDouBaoTtsService.stopTts();
+                        } catch (Exception e) {
+                            log.error("关闭豆包TTS音频流失败", e);
+                        }
+                    }
+                }
+            }
+        }).start();
+    }
+
+    /**
+     * 清理豆包TTS资源
+     */
+    private void cleanupDouBaoResources(ScheduledExecutorService scheduler,
+                                        AtomicReference<ScheduledFuture<?>> ttsTask,
+                                        AtomicBoolean ttsStopped,
+                                        FluxSink<CommonResult<AiChatMessageSendRespVO>> sink) {
+        if (ttsTask != null && ttsTask.get() != null) {
+            ttsTask.get().cancel(false);
+        }
+        if (scheduler != null && !scheduler.isShutdown()) {
+            scheduler.shutdownNow();
+        }
+        AtomicBoolean tempTtsStopped = new AtomicBoolean(false);
+        cleanupDouBaoTtsResources(tempTtsStopped);
+        ttsStopped.set(true);
+
+        // 确保SSE流终止
+        try {
+            if (sink != null && !sink.isCancelled()) {
+                log.info("在cleanupDouBaoResources中终止主SSE流");
+                sink.complete();
+                log.info("主SSE流已成功终止");
+            }
+        } catch (Exception e) {
+            log.error("在cleanupDouBaoResources中终止主SSE流异常", e);
+        }
+    }
+
+    /**
+     * 清理豆包TTS服务资源
+     */
+    private void cleanupDouBaoTtsResources(AtomicBoolean ttsStopped) {
+        try {
+            log.info("开始清理豆包TTS服务资源");
+
+            // 检查ttsStopped标志位,避免重复停止TTS服务
+            if (!ttsStopped.get()) {
+                try {
+                    log.info("调用streamingDouBaoTtsService.stopTts()");
+                    streamingDouBaoTtsService.stopTts();
+                    log.info("streamingDouBaoTtsService.stopTts()调用完成");
+                } catch (Exception e) {
+                    log.error("停止豆包TTS服务异常", e);
+                }
+            } else {
+                log.info("豆包TTS服务已停止,跳过stopTts()  调用");
+            }
+        } catch (Exception e) {
+            log.error("清理豆包TTS资源异常", e);
+        }
+    }
+
     /**
      * 清理所有资源
      */