Browse Source

1、TTS服务(出版:仅在后台调用音频设备播放,未传到前端)

liyanbo 8 months ago
parent
commit
7862666f9e

+ 25 - 3
byzs-module-ai/src/main/java/cn/iocoder/byzs/module/ai/controller/admin/tts/StreamTtsService.java

@@ -18,6 +18,7 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Consumer;
 
 @Service
 @Slf4j
@@ -42,6 +43,8 @@ public class StreamTtsService {
     private PlaybackRunnable playbackRunnable;
     private Thread playbackThread;
     private StreamInputTts synthesizer;
+    // ==== 添加音频数据回调 ====
+    private Consumer<byte[]> audioDataCallback;
 
     @PostConstruct
     public void init() {
@@ -88,10 +91,12 @@ public class StreamTtsService {
         synthesizer.setAppKey(appKey);
         synthesizer.setFormat(OutputFormatEnum.PCM);
         synthesizer.setSampleRate(SampleRateEnum.SAMPLE_RATE_24K);
-        synthesizer.setVoice("longxiaochun");
+        synthesizer.setVoice("aitong");
         synthesizer.setVolume(50);
         synthesizer.setPitchRate(0);
-        synthesizer.setSpeechRate(0);
+        synthesizer.setSpeechRate(50);
+//        synthesizer.setSplitText(true); // 如有类似配置需设为false
+//        synthesizer.setEnableSplit(false); // 禁用文本拆分,确保整句合成
 
         try {
             synthesizer.startStreamInputTts();
@@ -173,7 +178,20 @@ public class StreamTtsService {
                 }
                 byte[] bytesArray = new byte[message.remaining()];
                 message.get(bytesArray, 0, bytesArray.length);
-                audioPlayer.put(ByteBuffer.wrap(bytesArray));
+//                audioPlayer.put(ByteBuffer.wrap(bytesArray));
+
+                // ==== 调用回调传递音频数据 ====
+                log.info("生成音频数据: 长度={} bytes", bytesArray.length);
+                if (audioDataCallback != null) {
+                    try {
+                        audioDataCallback.accept(bytesArray);
+                        log.info("音频数据已发送到SSE流");
+                    } catch (Exception e) {
+                        log.error("音频数据回调失败", e);
+                    }
+                } else {
+                    log.warn("音频数据回调未设置,无法发送音频数据");
+                }
             }
 
             @Override
@@ -193,6 +211,10 @@ public class StreamTtsService {
         };
     }
 
+    public void setAudioDataCallback(Consumer<byte[]> callback) {
+        this.audioDataCallback = callback;
+    }
+
     class PlaybackRunnable implements Runnable {
         private AudioFormat af;
         private DataLine.Info info;

+ 141 - 4
byzs-module-ai/src/main/java/cn/iocoder/byzs/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -42,10 +42,22 @@ import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
+import reactor.core.publisher.ConnectableFlux;
+import reactor.core.publisher.DirectProcessor;
 import reactor.core.publisher.Flux;
+import reactor.core.publisher.FluxSink;
 
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 import java.time.LocalDateTime;
 import java.util.*;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
 import static cn.iocoder.byzs.framework.common.exception.util.ServiceExceptionUtil.exception;
@@ -170,12 +182,25 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
+        // 创建一个Processor用于合并文本和音频流
+        // 创建一个共享的响应对象
+        AtomicReference<CommonResult<AiChatMessageSendRespVO>> sharedResponse = new AtomicReference<>();
+
+        DirectProcessor<CommonResult<AiChatMessageSendRespVO>> processor = DirectProcessor.create();
+        FluxSink<CommonResult<AiChatMessageSendRespVO>> sink = processor.sink();
+
         // 4.3 初始化TTS服务
         streamTtsService.startTts();
 
         // 4.4 流式返回并处理TTS
         StringBuffer contentBuffer = new StringBuffer();
-        return streamResponse.map(chunk -> {
+        // 添加句子结束符正则表达式
+        Pattern sentencePattern = Pattern.compile("[。!?;.\n\r]"); // 增加换行符支持
+
+        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
+        AtomicReference<ScheduledFuture<?>> ttsTask = new AtomicReference<>();
+
+        Flux<CommonResult<AiChatMessageSendRespVO>> textStream  = streamResponse.map(chunk -> {
             // 处理知识库的返回,只有首次才有
             List<AiChatMessageRespVO.KnowledgeSegment> segments = null;
             if (StrUtil.isEmpty(contentBuffer)) {
@@ -194,19 +219,45 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
             contentBuffer.append(newContent);
 
             // 发送新内容到TTS服务进行语音合成
-            if (StrUtil.isNotBlank(newContent)) {
-                streamTtsService.sendText(newContent);
+            if (ttsTask.get() != null) {
+                ttsTask.get().cancel(false); // 取消之前的延迟任务
             }
-            return success(new AiChatMessageSendRespVO()
+            // 延迟500ms执行,合并短时间内到达的文本片段
+            ttsTask.set(scheduler.schedule(() -> {
+                Matcher matcher = sentencePattern.matcher(contentBuffer);
+                if (matcher.find()) {
+                    processCompleteSentence(contentBuffer, matcher);
+                } else if (contentBuffer.length() > 20) { // 最长20字未结束也处理
+                    processCompleteSentence(contentBuffer, contentBuffer.length());
+                }
+            }, 500, TimeUnit.MILLISECONDS));
+
+            CommonResult<AiChatMessageSendRespVO> result = success(new AiChatMessageSendRespVO()
+                    .setEventType("TEXT")
                     .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
                     .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
                             .setContent(newContent).setSegments(segments)));
+            sharedResponse.set(result);
+            return result;
         }).doOnComplete(() -> {
+            if (contentBuffer.length() > 0) {
+                streamTtsService.sendText(contentBuffer.toString());
+                contentBuffer.setLength(0);
+            }
             // 忽略租户,因为 Flux 异步无法透传租户
             TenantUtils.executeIgnore(() -> chatMessageMapper.updateById(
                     new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString())));
+
+
+            if (ttsTask.get() != null) {
+                ttsTask.get().cancel(false);
+            }
+            processRemainingText(contentBuffer); // 处理剩余文本
+            scheduler.shutdown(); // 关闭调度器
+
             // 通知TTS服务文本发送完成
             streamTtsService.stopTts();
+            sink.complete(); // 完成流
         }).doOnError(throwable -> {
             log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
             // 忽略租户,因为 Flux 异步无法透传租户
@@ -214,7 +265,61 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
                     new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage())));
             // 发生错误时停止TTS服务
             streamTtsService.stopTts();
+            // ==== 添加回调清理 ====
+            streamTtsService.setAudioDataCallback(null);
+            // =====================
+            sink.error(throwable); // 传递错误
+        })
+        // ==== 添加finally块清理 ====
+        .doFinally(signalType -> {
+            streamTtsService.setAudioDataCallback(null);
         }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
+
+        // 创建音频流
+        Flux<CommonResult<AiChatMessageSendRespVO>> audioStream = Flux.create(sink2 -> {
+            streamTtsService.setAudioDataCallback(audioBytes -> {
+                try {
+                    // 确保TTS输出WAV格式(带文件头)
+                    byte[] wavAudioWithHeader = addWavHeader(audioBytes, 24000, 16, 1); // 修改为24kHz(匹配StreamTtsService设置)
+                    String base64Audio = Base64.getEncoder().encodeToString(wavAudioWithHeader);
+                    AiChatMessageSendRespVO audioResp = new AiChatMessageSendRespVO();
+                    audioResp.setEventType("AUDIO");
+                    audioResp.setAudioData(base64Audio);
+                    sink2.next(success(audioResp));
+                } catch (Exception e) {
+                    log.error("[------------][userId({}) sendReqVO({}) <UNK> 发生异常]", userId, sendReqVO, e);
+                    sink2.error(e);
+                }
+            });
+        });
+
+        // 合并文本流和音频流,使用mergeWith而非mergeSequential
+        return textStream.mergeWith(audioStream);
+    }
+
+    // 处理完整句子
+    private void processCompleteSentence(StringBuffer buffer, Matcher matcher) {
+        String sentence = buffer.substring(0, matcher.end());
+        streamTtsService.sendText(sentence);
+        buffer.delete(0, matcher.end());
+        System.out.println("TTS合成完整句: " + sentence);
+    }
+
+    // 处理指定长度文本
+    private void processCompleteSentence(StringBuffer buffer, int length) {
+        String sentence = buffer.substring(0, length);
+        streamTtsService.sendText(sentence);
+        buffer.delete(0, length);
+        System.out.println("TTS合成长文本: " + sentence);
+    }
+
+    // 处理剩余文本
+    private void processRemainingText(StringBuffer buffer) {
+        if (buffer.length() > 0) {
+            streamTtsService.sendText(buffer.toString());
+            buffer.setLength(0);
+            System.out.println("TTS合成剩余文本: " + buffer.toString());
+        }
     }
 
     private List<AiKnowledgeSegmentSearchRespBO> recallKnowledgeSegment(String content,
@@ -237,6 +342,38 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         return knowledgeSegments;
     }
 
+    private byte[] addWavHeader(byte[] pcmData, int sampleRate, int bitsPerSample, int channels) {
+        int byteRate = sampleRate * channels * bitsPerSample / 8;
+        int blockAlign = channels * bitsPerSample / 8;
+        int dataSize = pcmData.length;
+        int fileSize = 44 + dataSize; // 修正:WAV头标准大小为44字节
+
+        ByteBuffer buffer = ByteBuffer.allocate(fileSize); // 分配足够大小的缓冲区
+        buffer.order(ByteOrder.LITTLE_ENDIAN);
+
+        // RIFF chunk
+        buffer.put("RIFF".getBytes());
+        buffer.putInt(fileSize - 8); // 总文件大小 - 8
+        buffer.put("WAVE".getBytes());
+
+        // fmt subchunk
+        buffer.put("fmt ".getBytes());
+        buffer.putInt(16); // PCM格式子块大小固定为16
+        buffer.putShort((short) 1); // 线性PCM编码
+        buffer.putShort((short) channels);
+        buffer.putInt(sampleRate);
+        buffer.putInt(byteRate);
+        buffer.putShort((short) blockAlign);
+        buffer.putShort((short) bitsPerSample);
+
+        // data subchunk
+        buffer.put("data".getBytes());
+        buffer.putInt(dataSize);
+        buffer.put(pcmData); // 此时缓冲区有足够空间,不会溢出
+
+        return buffer.array();
+    }
+
     private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
                                List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
                                AiModelDO model, AiChatMessageSendReqVO sendReqVO) {

+ 2 - 2
byzs-server/src/main/resources/application.yaml

@@ -3,8 +3,8 @@ spring:
     name: byzs-bjdx
 
   profiles:
-#    active: local
-    active: prodDev
+    active: local
+#    active: prodDev
 #    active: prod
 
   main: