|
|
@@ -42,10 +42,22 @@ import org.springframework.ai.chat.prompt.ChatOptions;
|
|
|
import org.springframework.ai.chat.prompt.Prompt;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
|
+import reactor.core.publisher.ConnectableFlux;
|
|
|
+import reactor.core.publisher.DirectProcessor;
|
|
|
import reactor.core.publisher.Flux;
|
|
|
+import reactor.core.publisher.FluxSink;
|
|
|
|
|
|
+import java.nio.ByteBuffer;
|
|
|
+import java.nio.ByteOrder;
|
|
|
import java.time.LocalDateTime;
|
|
|
import java.util.*;
|
|
|
+import java.util.concurrent.Executors;
|
|
|
+import java.util.concurrent.ScheduledExecutorService;
|
|
|
+import java.util.concurrent.ScheduledFuture;
|
|
|
+import java.util.concurrent.TimeUnit;
|
|
|
+import java.util.concurrent.atomic.AtomicReference;
|
|
|
+import java.util.regex.Matcher;
|
|
|
+import java.util.regex.Pattern;
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
import static cn.iocoder.byzs.framework.common.exception.util.ServiceExceptionUtil.exception;
|
|
|
@@ -170,12 +182,25 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
|
|
|
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
|
|
|
|
|
+ // 创建一个Processor用于合并文本和音频流
|
|
|
+ // 创建一个共享的响应对象
|
|
|
+ AtomicReference<CommonResult<AiChatMessageSendRespVO>> sharedResponse = new AtomicReference<>();
|
|
|
+
|
|
|
+ DirectProcessor<CommonResult<AiChatMessageSendRespVO>> processor = DirectProcessor.create();
|
|
|
+ FluxSink<CommonResult<AiChatMessageSendRespVO>> sink = processor.sink();
|
|
|
+
|
|
|
// 4.3 初始化TTS服务
|
|
|
streamTtsService.startTts();
|
|
|
|
|
|
// 4.4 流式返回并处理TTS
|
|
|
StringBuffer contentBuffer = new StringBuffer();
|
|
|
- return streamResponse.map(chunk -> {
|
|
|
+ // 添加句子结束符正则表达式
|
|
|
+ Pattern sentencePattern = Pattern.compile("[。!?;.\n\r]"); // 增加换行符支持
|
|
|
+
|
|
|
+ ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
|
|
|
+ AtomicReference<ScheduledFuture<?>> ttsTask = new AtomicReference<>();
|
|
|
+
|
|
|
+ Flux<CommonResult<AiChatMessageSendRespVO>> textStream = streamResponse.map(chunk -> {
|
|
|
// 处理知识库的返回,只有首次才有
|
|
|
List<AiChatMessageRespVO.KnowledgeSegment> segments = null;
|
|
|
if (StrUtil.isEmpty(contentBuffer)) {
|
|
|
@@ -194,19 +219,45 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
contentBuffer.append(newContent);
|
|
|
|
|
|
// 发送新内容到TTS服务进行语音合成
|
|
|
- if (StrUtil.isNotBlank(newContent)) {
|
|
|
- streamTtsService.sendText(newContent);
|
|
|
+ if (ttsTask.get() != null) {
|
|
|
+ ttsTask.get().cancel(false); // 取消之前的延迟任务
|
|
|
}
|
|
|
- return success(new AiChatMessageSendRespVO()
|
|
|
+ // 延迟500ms执行,合并短时间内到达的文本片段
|
|
|
+ ttsTask.set(scheduler.schedule(() -> {
|
|
|
+ Matcher matcher = sentencePattern.matcher(contentBuffer);
|
|
|
+ if (matcher.find()) {
|
|
|
+ processCompleteSentence(contentBuffer, matcher);
|
|
|
+ } else if (contentBuffer.length() > 20) { // 最长20字未结束也处理
|
|
|
+ processCompleteSentence(contentBuffer, contentBuffer.length());
|
|
|
+ }
|
|
|
+ }, 500, TimeUnit.MILLISECONDS));
|
|
|
+
|
|
|
+ CommonResult<AiChatMessageSendRespVO> result = success(new AiChatMessageSendRespVO()
|
|
|
+ .setEventType("TEXT")
|
|
|
.setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
|
|
|
.setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
|
|
|
.setContent(newContent).setSegments(segments)));
|
|
|
+ sharedResponse.set(result);
|
|
|
+ return result;
|
|
|
}).doOnComplete(() -> {
|
|
|
+ if (contentBuffer.length() > 0) {
|
|
|
+ streamTtsService.sendText(contentBuffer.toString());
|
|
|
+ contentBuffer.setLength(0);
|
|
|
+ }
|
|
|
// 忽略租户,因为 Flux 异步无法透传租户
|
|
|
TenantUtils.executeIgnore(() -> chatMessageMapper.updateById(
|
|
|
new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString())));
|
|
|
+
|
|
|
+
|
|
|
+ if (ttsTask.get() != null) {
|
|
|
+ ttsTask.get().cancel(false);
|
|
|
+ }
|
|
|
+ processRemainingText(contentBuffer); // 处理剩余文本
|
|
|
+ scheduler.shutdown(); // 关闭调度器
|
|
|
+
|
|
|
// 通知TTS服务文本发送完成
|
|
|
streamTtsService.stopTts();
|
|
|
+ sink.complete(); // 完成流
|
|
|
}).doOnError(throwable -> {
|
|
|
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
|
|
|
// 忽略租户,因为 Flux 异步无法透传租户
|
|
|
@@ -214,7 +265,61 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage())));
|
|
|
// 发生错误时停止TTS服务
|
|
|
streamTtsService.stopTts();
|
|
|
+ // ==== 添加回调清理 ====
|
|
|
+ streamTtsService.setAudioDataCallback(null);
|
|
|
+ // =====================
|
|
|
+ sink.error(throwable); // 传递错误
|
|
|
+ })
|
|
|
+ // ==== 添加finally块清理 ====
|
|
|
+ .doFinally(signalType -> {
|
|
|
+ streamTtsService.setAudioDataCallback(null);
|
|
|
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
|
|
|
+
|
|
|
+ // 创建音频流
|
|
|
+ Flux<CommonResult<AiChatMessageSendRespVO>> audioStream = Flux.create(sink2 -> {
|
|
|
+ streamTtsService.setAudioDataCallback(audioBytes -> {
|
|
|
+ try {
|
|
|
+ // 确保TTS输出WAV格式(带文件头)
|
|
|
+ byte[] wavAudioWithHeader = addWavHeader(audioBytes, 24000, 16, 1); // 修改为24kHz(匹配StreamTtsService设置)
|
|
|
+ String base64Audio = Base64.getEncoder().encodeToString(wavAudioWithHeader);
|
|
|
+ AiChatMessageSendRespVO audioResp = new AiChatMessageSendRespVO();
|
|
|
+ audioResp.setEventType("AUDIO");
|
|
|
+ audioResp.setAudioData(base64Audio);
|
|
|
+ sink2.next(success(audioResp));
|
|
|
+ } catch (Exception e) {
|
|
|
+ log.error("[------------][userId({}) sendReqVO({}) <UNK> 发生异常]", userId, sendReqVO, e);
|
|
|
+ sink2.error(e);
|
|
|
+ }
|
|
|
+ });
|
|
|
+ });
|
|
|
+
|
|
|
+ // 合并文本流和音频流,使用mergeWith而非mergeSequential
|
|
|
+ return textStream.mergeWith(audioStream);
|
|
|
+ }
|
|
|
+
|
|
|
+ // 处理完整句子
|
|
|
+ private void processCompleteSentence(StringBuffer buffer, Matcher matcher) {
|
|
|
+ String sentence = buffer.substring(0, matcher.end());
|
|
|
+ streamTtsService.sendText(sentence);
|
|
|
+ buffer.delete(0, matcher.end());
|
|
|
+ System.out.println("TTS合成完整句: " + sentence);
|
|
|
+ }
|
|
|
+
|
|
|
+ // 处理指定长度文本
|
|
|
+ private void processCompleteSentence(StringBuffer buffer, int length) {
|
|
|
+ String sentence = buffer.substring(0, length);
|
|
|
+ streamTtsService.sendText(sentence);
|
|
|
+ buffer.delete(0, length);
|
|
|
+ System.out.println("TTS合成长文本: " + sentence);
|
|
|
+ }
|
|
|
+
|
|
|
+ // 处理剩余文本
|
|
|
+ private void processRemainingText(StringBuffer buffer) {
|
|
|
+ if (buffer.length() > 0) {
|
|
|
+ streamTtsService.sendText(buffer.toString());
|
|
|
+ buffer.setLength(0);
|
|
|
+ System.out.println("TTS合成剩余文本: " + buffer.toString());
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
private List<AiKnowledgeSegmentSearchRespBO> recallKnowledgeSegment(String content,
|
|
|
@@ -237,6 +342,38 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
return knowledgeSegments;
|
|
|
}
|
|
|
|
|
|
+ private byte[] addWavHeader(byte[] pcmData, int sampleRate, int bitsPerSample, int channels) {
|
|
|
+ int byteRate = sampleRate * channels * bitsPerSample / 8;
|
|
|
+ int blockAlign = channels * bitsPerSample / 8;
|
|
|
+ int dataSize = pcmData.length;
|
|
|
+ int fileSize = 44 + dataSize; // 修正:WAV头标准大小为44字节
|
|
|
+
|
|
|
+ ByteBuffer buffer = ByteBuffer.allocate(fileSize); // 分配足够大小的缓冲区
|
|
|
+ buffer.order(ByteOrder.LITTLE_ENDIAN);
|
|
|
+
|
|
|
+ // RIFF chunk
|
|
|
+ buffer.put("RIFF".getBytes());
|
|
|
+ buffer.putInt(fileSize - 8); // 总文件大小 - 8
|
|
|
+ buffer.put("WAVE".getBytes());
|
|
|
+
|
|
|
+ // fmt subchunk
|
|
|
+ buffer.put("fmt ".getBytes());
|
|
|
+ buffer.putInt(16); // PCM格式子块大小固定为16
|
|
|
+ buffer.putShort((short) 1); // 线性PCM编码
|
|
|
+ buffer.putShort((short) channels);
|
|
|
+ buffer.putInt(sampleRate);
|
|
|
+ buffer.putInt(byteRate);
|
|
|
+ buffer.putShort((short) blockAlign);
|
|
|
+ buffer.putShort((short) bitsPerSample);
|
|
|
+
|
|
|
+ // data subchunk
|
|
|
+ buffer.put("data".getBytes());
|
|
|
+ buffer.putInt(dataSize);
|
|
|
+ buffer.put(pcmData); // 此时缓冲区有足够空间,不会溢出
|
|
|
+
|
|
|
+ return buffer.array();
|
|
|
+ }
|
|
|
+
|
|
|
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
|
|
|
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
|
|
|
AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
|