Преглед на файлове

1、ai对话TTS服务新增提供固定答案,转语音

liyanbo преди 8 месеца
родител
ревизия
0b1cb24ff4

+ 3 - 0
byzs-module-ai/src/main/java/cn/iocoder/byzs/module/ai/controller/admin/chat/vo/message/AiChatMessageSendReqVO.java

@@ -22,4 +22,7 @@ public class AiChatMessageSendReqVO {
     @Schema(description = "是否携带上下文", example = "true")
     private Boolean useContext;
 
+    @Schema(description = "携带答案", example = "true")
+    private String contentAnswer;
+
 }

+ 2 - 33
byzs-module-ai/src/main/java/cn/iocoder/byzs/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -31,6 +31,7 @@ import cn.iocoder.byzs.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRe
 import cn.iocoder.byzs.module.ai.service.model.AiChatRoleService;
 import cn.iocoder.byzs.module.ai.service.model.AiModelService;
 import cn.iocoder.byzs.module.ai.service.model.AiToolService;
+import cn.iocoder.byzs.module.ai.util.tts.WavHeader;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.chat.messages.Message;
@@ -287,7 +288,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
                     byte[] processedAudio;
                     if (isFirstChunk.getAndSet(false)) {
                         // 仅首包添加WAV头
-                        processedAudio = addWavHeader(audioBytes, 16000, 16, 1);
+                        processedAudio = WavHeader.addWavHeader(audioBytes, 16000, 16, 1);
                         log.info("首包音频带WAV头,长度={} bytes", processedAudio.length);
                     } else {
                         // 后续包直接使用原始PCM数据
@@ -363,38 +364,6 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         return knowledgeSegments;
     }
 
-    private byte[] addWavHeader(byte[] pcmData, int sampleRate, int bitsPerSample, int channels) {
-        int byteRate = sampleRate * channels * bitsPerSample / 8;
-        int blockAlign = channels * bitsPerSample / 8;
-        int dataSize = pcmData.length;
-        int fileSize = 44 + dataSize; // 修正:WAV头标准大小为44字节
-
-        ByteBuffer buffer = ByteBuffer.allocate(fileSize); // 分配足够大小的缓冲区
-        buffer.order(ByteOrder.LITTLE_ENDIAN);
-
-        // RIFF chunk
-        buffer.put("RIFF".getBytes());
-        buffer.putInt(fileSize - 8); // 总文件大小 - 8
-        buffer.put("WAVE".getBytes());
-
-        // fmt subchunk
-        buffer.put("fmt ".getBytes());
-        buffer.putInt(16); // PCM格式子块大小固定为16
-        buffer.putShort((short) 1); // 线性PCM编码
-        buffer.putShort((short) channels);
-        buffer.putInt(sampleRate);
-        buffer.putInt(byteRate);
-        buffer.putShort((short) blockAlign);
-        buffer.putShort((short) bitsPerSample);
-
-        // data subchunk
-        buffer.put("data".getBytes());
-        buffer.putInt(dataSize);
-        buffer.put(pcmData); // 此时缓冲区有足够空间,不会溢出
-
-        return buffer.array();
-    }
-
     private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
                                List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
                                AiModelDO model, AiChatMessageSendReqVO sendReqVO) {

+ 39 - 0
byzs-module-ai/src/main/java/cn/iocoder/byzs/module/ai/util/tts/WavHeader.java

@@ -0,0 +1,39 @@
+package cn.iocoder.byzs.module.ai.util.tts;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+public class WavHeader {
+    /**
+     * Builds a WAV byte array by writing a 44-byte RIFF/WAVE header in front of raw PCM data.
+     *
+     * <p>The RIFF size and the data size are both derived from {@code pcmData.length}, so the
+     * header describes exactly the bytes passed to this call. The callers visible in this change
+     * wrap only the first streamed chunk, so the sizes cover that chunk alone.
+     * NOTE(review): confirm that audio consumers tolerate this for streamed playback.
+     *
+     * <p>NOTE(review): the chunk tags are encoded with {@code String.getBytes()} without a
+     * charset argument, i.e. with the platform default charset; an explicit ASCII charset
+     * would make the output platform-independent. Confirm before changing.
+     *
+     * @param pcmData       raw PCM bytes to wrap; its length is read, so it must not be null
+     * @param sampleRate    sample rate written to the fmt subchunk
+     * @param bitsPerSample bits per sample written to the fmt subchunk
+     * @param channels      channel count written to the fmt subchunk
+     * @return a new array of {@code 44 + pcmData.length} bytes: the header followed by pcmData
+     */
+    public static byte[] addWavHeader(byte[] pcmData, int sampleRate, int bitsPerSample, int channels) {
+        int byteRate = sampleRate * channels * bitsPerSample / 8; // bytes of audio per second
+        int blockAlign = channels * bitsPerSample / 8; // bytes per sample frame across all channels
+        int dataSize = pcmData.length;
+        int fileSize = 44 + dataSize; // the header fields written below total 44 bytes
+
+        ByteBuffer buffer = ByteBuffer.allocate(fileSize);
+        buffer.order(ByteOrder.LITTLE_ENDIAN); // multi-byte header fields are written little-endian
+
+        // RIFF chunk
+        buffer.put("RIFF".getBytes());
+        buffer.putInt(fileSize - 8); // total size minus the 4-byte tag and this 4-byte size field
+        buffer.put("WAVE".getBytes());
+
+        // fmt subchunk
+        buffer.put("fmt ".getBytes());
+        buffer.putInt(16); // length of the fmt fields that follow (2 + 2 + 4 + 4 + 2 + 2)
+        buffer.putShort((short) 1); // audio format code 1 = linear PCM
+        buffer.putShort((short) channels);
+        buffer.putInt(sampleRate);
+        buffer.putInt(byteRate);
+        buffer.putShort((short) blockAlign);
+        buffer.putShort((short) bitsPerSample);
+
+        // data subchunk
+        buffer.put("data".getBytes());
+        buffer.putInt(dataSize);
+        buffer.put(pcmData);
+
+        return buffer.array(); // backing array is exactly fileSize bytes and is fully written
+    }
+}

+ 1 - 0
byzs-module-infra/src/main/java/cn/iocoder/byzs/module/infra/framework/file/core/utils/FileTypeUtils.java

@@ -117,6 +117,7 @@ public class FileTypeUtils {
 
 
         // 设置 header 和 contentType
+        //attachment(下载)、inline(预览)
         response.setHeader("Content-Disposition", "attachment;filename=" + HttpUtils.encodeUtf8(filename));
         String contentType = getMineType(content, filename);
         response.setContentType(contentType);

+ 9 - 1
byzs-web/src/main/java/cn/iocoder/byzs/module/web/controller/admin/ai/WebAiController.java

@@ -19,12 +19,14 @@ import cn.iocoder.byzs.module.ai.service.chat.AiChatMessageService;
 import cn.iocoder.byzs.module.ai.service.image.AiImageService;
 import cn.iocoder.byzs.module.ai.service.model.AiChatRoleService;
 import cn.iocoder.byzs.module.web.controller.admin.ai.vo.WebAiChatRoleVO;
+import cn.iocoder.byzs.module.web.service.ai.WebAiServiceImpl;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.Parameter;
 import io.swagger.v3.oas.annotations.tags.Tag;
 import jakarta.annotation.Resource;
 import jakarta.annotation.security.PermitAll;
 import jakarta.validation.Valid;
+import jodd.util.StringUtil;
 import org.springframework.http.MediaType;
 import org.springframework.validation.annotation.Validated;
 import org.springframework.web.bind.annotation.*;
@@ -49,6 +51,8 @@ public class WebAiController {
     private AiImageService imageService;
     @Resource
     private AiChatRoleService chatRoleService;
+    @Resource
+    private WebAiServiceImpl webAiService;
     
     // ================ 智能问答 ================
 
@@ -68,7 +72,11 @@ public class WebAiController {
     @Operation(summary = "智能问答-发送消息(流式)", description = "流式返回,响应较快")
     @PostMapping(value = "/dialogue-send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
     public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) {
-        return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId());
+        if (StringUtil.isNotEmpty(sendReqVO.getContentAnswer())) { // request carries a fixed answer: stream that answer via the web AI service
+            return webAiService.sendSpecifiedAnswerStream(sendReqVO, getLoginUserId());
+        }else{ // no fixed answer: keep using the chat message service as before
+            return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId());
+        }
     }
 
     // ================ 绘图管理 ================

+ 177 - 0
byzs-web/src/main/java/cn/iocoder/byzs/module/web/service/ai/WebAiServiceImpl.java

@@ -0,0 +1,177 @@
+package cn.iocoder.byzs.module.web.service.ai;
+
+import cn.hutool.core.util.ObjUtil;
+import cn.iocoder.byzs.framework.common.pojo.CommonResult;
+import cn.iocoder.byzs.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
+import cn.iocoder.byzs.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
+import cn.iocoder.byzs.module.ai.dal.dataobject.chat.AiChatConversationDO;
+import cn.iocoder.byzs.module.ai.dal.dataobject.model.AiChatRoleDO;
+import cn.iocoder.byzs.module.ai.dal.dataobject.tts.AiTtsDO;
+import cn.iocoder.byzs.module.ai.dal.mysql.tts.AiTtsMapper;
+import cn.iocoder.byzs.module.ai.enums.ErrorCodeConstants;
+import cn.iocoder.byzs.module.ai.service.chat.AiChatConversationService;
+import cn.iocoder.byzs.module.ai.service.model.AiChatRoleService;
+import cn.iocoder.byzs.module.ai.util.tts.StreamTtsService;
+import cn.iocoder.byzs.module.ai.util.tts.WavHeader;
+import jakarta.annotation.Resource;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.factory.ObjectProvider;
+import org.springframework.stereotype.Service;
+import org.springframework.validation.annotation.Validated;
+import reactor.core.publisher.Flux;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.Base64;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import static cn.iocoder.byzs.framework.common.exception.util.ServiceExceptionUtil.exception;
+import static cn.iocoder.byzs.module.ai.enums.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
+import static cn.iocoder.byzs.framework.common.pojo.CommonResult.error;
+import static cn.iocoder.byzs.framework.common.pojo.CommonResult.success;
+
+/**
+ * webAi service implementation.
+ *
+ * <p>Streams a caller-specified answer back to the client as one text event plus
+ * text-to-speech audio events.
+ *
+ * @author lyb
+ */
+@Service
+@Validated
+@Slf4j
+public class WebAiServiceImpl {
+
+    @Resource
+    private AiChatConversationService chatConversationService;
+
+    @Resource
+    private AiChatRoleService chatRoleService;
+
+    @Resource
+    private AiTtsMapper ttsMapper;
+
+    @Resource
+    private ObjectProvider<StreamTtsService> streamTtsServiceProvider; // a StreamTtsService is obtained per call via getObject()
+    /**
+     * Returns the answer supplied in the request as a merged stream of one TEXT event and
+     * AUDIO events produced by the TTS configuration of the conversation's chat role.
+     *
+     * <p>The conversation lookup, the ownership check, the TTS lookup and {@code startTts} all
+     * run when this method is called, before the returned Flux is subscribed. The method throws
+     * the exception built for {@code CHAT_CONVERSATION_NOT_EXISTS} when the conversation's user
+     * id differs from {@code userId}, and for {@code TTS_NOT_EXISTS} when the chat role is
+     * missing, has no TTS id, or the TTS record is not found.
+     *
+     * <p>NOTE(review): because {@code startTts} and the scheduler are created before
+     * subscription, neither is released by this method if the returned Flux is never
+     * subscribed; confirm whether that can happen.
+     *
+     * @param sendReqVO request; its conversation id and content answer are read
+     * @param userId    id compared with the user id stored on the conversation
+     * @return merge of the text stream and the audio stream; AUDIO events carry Base64 data
+     */
+    public Flux<CommonResult<AiChatMessageSendRespVO>> sendSpecifiedAnswerStream(AiChatMessageSendReqVO sendReqVO, Long userId) {
+        // 1. Validate that the conversation exists and belongs to the given user
+        AiChatConversationDO conversation = chatConversationService
+                .validateChatConversationExists(sendReqVO.getConversationId());
+        if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
+            throw exception(CHAT_CONVERSATION_NOT_EXISTS);
+        }
+
+        // 2. Load the TTS configuration referenced by the conversation's chat role
+        AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
+        if (chatRole == null || chatRole.getTtsId() == null) {
+            throw exception(ErrorCodeConstants.TTS_NOT_EXISTS);
+        }
+        AiTtsDO aiTtsDO = ttsMapper.selectById(chatRole.getTtsId());
+        if (aiTtsDO == null) {
+            throw exception(ErrorCodeConstants.TTS_NOT_EXISTS);
+        }
+
+        // 3. Initialise the TTS service
+        StreamTtsService streamTtsService = streamTtsServiceProvider.getObject();
+        streamTtsService.startTts(aiTtsDO);
+
+        // 4. Prepare the specified answer content
+        String contentAnswer = sendReqVO.getContentAnswer();
+        StringBuffer contentTTSBuffer = new StringBuffer(contentAnswer); // text not yet sent to TTS; consumed by the scheduled task and by the completion hook below
+        Pattern sentencePattern = Pattern.compile("[。!?;\n\r]"); // matches one sentence-ending punctuation mark or line break
+
+        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor(); // created per call; shut down in the hooks below
+        AtomicReference<ScheduledFuture<?>> ttsTask = new AtomicReference<>();
+
+        /* 5. Build the text stream: a single TEXT event carrying the whole answer.
+         * On completion the remaining buffer is sent to TTS, the scheduled task (if set) is
+         * cancelled and the scheduler is shut down; doFinally then stops TTS.
+         * NOTE(review): this stream completes as soon as its single element is emitted, so the
+         * flush, scheduler.shutdown() and stopTts() may run before the audio stream below has
+         * registered its callbacks or scheduled its task. Verify the ordering under Flux.merge.
+         * NOTE(review): doOnError and doFinally both call stopTts(), so it looks like it can be
+         * invoked twice on error. Confirm stopTts() tolerates that.
+         */
+        Flux<CommonResult<AiChatMessageSendRespVO>> textStream = Flux.just(success(
+                new AiChatMessageSendRespVO()
+                        .setEventType("TEXT")
+                        .setReceive(new AiChatMessageSendRespVO.Message().setContent(contentAnswer))
+        )).doOnComplete(() -> {
+            processRemainingText(streamTtsService, contentTTSBuffer);
+            if (ttsTask.get() != null) {
+                ttsTask.get().cancel(false);
+            }
+            scheduler.shutdown();
+        }).doOnError(throwable -> {
+            streamTtsService.stopTts();
+        }).doFinally(signalType -> {
+            streamTtsService.stopTts();
+        });
+
+        // 6. Build the audio stream: each TTS audio callback becomes one AUDIO event
+        Flux<CommonResult<AiChatMessageSendRespVO>> audioStream = Flux.create(sink -> {
+            AtomicBoolean isFirstChunk = new AtomicBoolean(true);
+            streamTtsService.setAudioDataCallback(audioBytes -> {
+                try {
+                    byte[] processedAudio;
+                    if (isFirstChunk.getAndSet(false)) { // only the first chunk gets a WAV header
+                        processedAudio = WavHeader.addWavHeader(audioBytes, 16000, 16, 1);
+                        log.info("首包音频带WAV头,长度={} bytes", processedAudio.length);
+                    } else {
+                        processedAudio = audioBytes; // later chunks are forwarded unchanged
+                    }
+                    String base64Audio = Base64.getEncoder().encodeToString(processedAudio);
+                    AiChatMessageSendRespVO audioResp = new AiChatMessageSendRespVO();
+                    audioResp.setEventType("AUDIO");
+                    audioResp.setAudioData(base64Audio);
+                    sink.next(success(audioResp));
+                } catch (Exception e) {
+                    log.error("[TTS处理异常] 音频编码失败", e);
+                    sink.error(new RuntimeException("TTS音频处理失败: " + e.getMessage(), e));
+                }
+            });
+            streamTtsService.setOnCompleteCallback(sink::complete); // the TTS completion callback ends the audio stream
+
+            // Schedule text processing: the task runs once, after a 100 ms delay, and sends at most one segment
+            ttsTask.set(scheduler.schedule(() -> {
+                Matcher matcher = sentencePattern.matcher(contentTTSBuffer);
+                if (matcher.find()) {
+                    processCompleteSentence(streamTtsService, contentTTSBuffer, matcher);
+                } else if (!contentTTSBuffer.isEmpty()) { // no sentence delimiter found: send everything that is buffered
+                    processCompleteSentence(streamTtsService, contentTTSBuffer, contentTTSBuffer.length());
+                }
+            }, 100, TimeUnit.MILLISECONDS));
+        });
+
+        // 7. Merge both streams and return; on termination clear the TTS callbacks and stop the scheduler
+        return Flux.merge(textStream, audioStream)
+                .doFinally(signalType -> {
+                    streamTtsService.setAudioDataCallback(null);
+                    streamTtsService.setOnCompleteCallback(null);
+                    scheduler.shutdownNow();
+                });
+    }
+
+    /**
+     * Sends the buffer content up to and including the end of the matcher's current match to
+     * TTS, removes that prefix from the buffer, and logs the sent text.
+     */
+    private void processCompleteSentence(StreamTtsService streamTtsService, StringBuffer buffer, Matcher matcher) {
+        String sentence = buffer.substring(0, matcher.end());
+        streamTtsService.sendText(sentence);
+        buffer.delete(0, matcher.end());
+        log.info("TTS合成完整句: {}", sentence);
+    }
+
+    /**
+     * Sends the first {@code length} characters of the buffer to TTS, removes them from the
+     * buffer, and logs the sent text.
+     */
+    private void processCompleteSentence(StreamTtsService streamTtsService, StringBuffer buffer, int length) {
+        String sentence = buffer.substring(0, length);
+        streamTtsService.sendText(sentence);
+        buffer.delete(0, length);
+        log.info("TTS合成长文本: {}", sentence);
+    }
+
+    /**
+     * Sends whatever is left in the buffer to TTS and empties the buffer; does nothing when the
+     * buffer is already empty.
+     */
+    private void processRemainingText(StreamTtsService streamTtsService, StringBuffer buffer) {
+        if (!buffer.isEmpty()) {
+            streamTtsService.sendText(buffer.toString());
+            buffer.setLength(0);
+        }
+    }
+
+}