|
|
@@ -1,337 +0,0 @@
|
|
|
-package cn.iocoder.byzs.module.ai.controller.admin.speech;
|
|
|
-
|
|
|
-import com.alibaba.nls.client.AccessToken;
|
|
|
-import com.alibaba.nls.client.protocol.NlsClient;
|
|
|
-import com.alibaba.nls.client.protocol.InputFormatEnum;
|
|
|
-import com.alibaba.nls.client.protocol.SampleRateEnum;
|
|
|
-import com.alibaba.nls.client.protocol.asr.SpeechTranscriber;
|
|
|
-import com.alibaba.nls.client.protocol.asr.SpeechTranscriberListener;
|
|
|
-import com.alibaba.nls.client.protocol.asr.SpeechTranscriberResponse;
|
|
|
-import org.slf4j.Logger;
|
|
|
-import org.slf4j.LoggerFactory;
|
|
|
-import org.springframework.beans.factory.annotation.Value;
|
|
|
-import org.springframework.web.bind.annotation.*;
|
|
|
-import org.springframework.web.multipart.MultipartFile;
|
|
|
-import org.springframework.web.socket.CloseStatus;
|
|
|
-import org.springframework.web.socket.TextMessage;
|
|
|
-import org.springframework.web.socket.WebSocketSession;
|
|
|
-import org.springframework.web.socket.handler.TextWebSocketHandler;
|
|
|
-
|
|
|
-import jakarta.servlet.http.HttpServletRequest;
|
|
|
-import org.springframework.web.context.request.RequestContextHolder;
|
|
|
-import org.springframework.web.context.request.ServletRequestAttributes;
|
|
|
-import org.springframework.web.multipart.MultipartFile;
|
|
|
-import org.springframework.web.multipart.MultipartHttpServletRequest;
|
|
|
-import java.util.Map;
|
|
|
-import java.util.concurrent.ConcurrentHashMap;
|
|
|
-import java.util.concurrent.CountDownLatch;
|
|
|
-import java.util.concurrent.TimeUnit;
|
|
|
-
|
|
|
-@RestController
|
|
|
-@RequestMapping("/admin/ai/speech")
|
|
|
-public class RealTimeSpeechController {
|
|
|
-
|
|
|
- private static final Logger logger = LoggerFactory.getLogger(RealTimeSpeechController.class);
|
|
|
-
|
|
|
- @Value("${ai.aliyun.app-key:4SUOF4LfaU7FekyW}")
|
|
|
- private String appKey;
|
|
|
-
|
|
|
- @Value("${ai.aliyun.access-key-id:LTAI5tQhMPLXtSgXiPiWbw6D}")
|
|
|
- private String accessKeyId;
|
|
|
-
|
|
|
- @Value("${ai.aliyun.access-key-secret:HCXpFYjl4swk0qwfIKa9s2bXx0AWcG}")
|
|
|
- private String accessKeySecret;
|
|
|
-
|
|
|
- @Value("${ai.aliyun.nls-gateway-url:wss://nls-gateway-cn-shanghai.aliyuncs.com/ws/v1}")
|
|
|
- private String nlsGatewayUrl;
|
|
|
-
|
|
|
- // 存储每个会话的识别结果
|
|
|
- private final Map<String, StringBuilder> transcriptionResults = new ConcurrentHashMap<>();
|
|
|
-
|
|
|
- // 存储每个会话的NlsClient
|
|
|
- private final Map<String, NlsClient> nlsClients = new ConcurrentHashMap<>();
|
|
|
-
|
|
|
- // 存储每个会话的SpeechTranscriber
|
|
|
- private final Map<String, SpeechTranscriber> transcribers = new ConcurrentHashMap<>();
|
|
|
-
|
|
|
- // 存储每个会话的结束信号
|
|
|
- private final Map<String, CountDownLatch> latches = new ConcurrentHashMap<>();
|
|
|
-
|
|
|
- // 存储WebSocket会话,用于实时返回中间结果
|
|
|
- private final Map<String, WebSocketSession> webSocketSessions = new ConcurrentHashMap<>();
|
|
|
-
|
|
|
- // WebSocket处理器
|
|
|
- public class SpeechWebSocketHandler extends TextWebSocketHandler {
|
|
|
- private final Map<String, WebSocketSession> webSocketSessions;
|
|
|
- private final Logger logger = LoggerFactory.getLogger(SpeechWebSocketHandler.class);
|
|
|
-
|
|
|
- public SpeechWebSocketHandler(Map<String, WebSocketSession> webSocketSessions) {
|
|
|
- this.webSocketSessions = webSocketSessions;
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void afterConnectionEstablished(WebSocketSession session) throws Exception {
|
|
|
- // 从会话参数中获取sessionId
|
|
|
- String sessionId = session.getUri().getQuery().split("=")[1];
|
|
|
- webSocketSessions.put(sessionId, session);
|
|
|
- logger.info("WebSocket连接建立,sessionId: {}", sessionId);
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
|
|
|
- // 处理客户端消息
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
|
|
|
- // 从会话参数中获取sessionId
|
|
|
- String sessionId = session.getUri().getQuery().split("=")[1];
|
|
|
- webSocketSessions.remove(sessionId);
|
|
|
- logger.info("WebSocket连接关闭,sessionId: {}", sessionId);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * 开始语音识别会话
|
|
|
- */
|
|
|
- @PostMapping("/start")
|
|
|
- public Map<String, Object> startRecognition() {
|
|
|
- // 生成唯一的会话ID
|
|
|
- String sessionId = java.util.UUID.randomUUID().toString();
|
|
|
- try {
|
|
|
- // 清理旧的会话资源
|
|
|
- cleanupSession(sessionId);
|
|
|
-
|
|
|
- // 初始化NlsClient
|
|
|
- AccessToken accessToken = new AccessToken(accessKeyId, accessKeySecret);
|
|
|
- accessToken.apply();
|
|
|
- NlsClient client = new NlsClient(nlsGatewayUrl, accessToken.getToken());
|
|
|
- nlsClients.put(sessionId, client);
|
|
|
-
|
|
|
- // 初始化识别结果
|
|
|
- transcriptionResults.put(sessionId, new StringBuilder());
|
|
|
-
|
|
|
- // 初始化结束信号
|
|
|
- latches.put(sessionId, new CountDownLatch(1));
|
|
|
-
|
|
|
- // 创建SpeechTranscriber
|
|
|
- SpeechTranscriber transcriber = new SpeechTranscriber(client, getTranscriberListener(sessionId));
|
|
|
- transcriber.setAppKey(appKey);
|
|
|
- transcriber.setFormat(InputFormatEnum.PCM);
|
|
|
- transcriber.setSampleRate(SampleRateEnum.SAMPLE_RATE_16K);
|
|
|
- transcriber.setEnableIntermediateResult(true);
|
|
|
- transcriber.setEnablePunctuation(true);
|
|
|
- transcriber.setEnableITN(false);
|
|
|
- transcriber.start();
|
|
|
-
|
|
|
- transcribers.put(sessionId, transcriber);
|
|
|
-
|
|
|
- // 启动心跳线程,避免WebSocket会话超时
|
|
|
- startHeartbeatThread(sessionId, transcriber);
|
|
|
-
|
|
|
- return Map.of("success", true, "message", "语音识别会话已开始", "sessionId", sessionId);
|
|
|
- } catch (Exception e) {
|
|
|
- logger.error("开始语音识别失败", e);
|
|
|
- cleanupSession(sessionId);
|
|
|
- return Map.of("success", false, "message", "开始语音识别失败: " + e.getMessage());
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * 启动心跳线程,定期发送空数据保持WebSocket连接
|
|
|
- */
|
|
|
- private void startHeartbeatThread(String sessionId, SpeechTranscriber transcriber) {
|
|
|
- Thread heartbeatThread = new Thread(() -> {
|
|
|
- try {
|
|
|
- while (transcribers.containsKey(sessionId)) {
|
|
|
- // 发送空数据保持连接
|
|
|
- transcriber.send(new byte[0], 0);
|
|
|
- Thread.sleep(5000); // 每5秒发送一次心跳
|
|
|
- }
|
|
|
- } catch (Exception e) {
|
|
|
- logger.error("心跳线程异常", e);
|
|
|
- }
|
|
|
- });
|
|
|
- heartbeatThread.setDaemon(true);
|
|
|
- heartbeatThread.start();
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * 清理会话资源
|
|
|
- */
|
|
|
- private void cleanupSession(String sessionId) {
|
|
|
- try {
|
|
|
- SpeechTranscriber transcriber = transcribers.remove(sessionId);
|
|
|
- if (transcriber != null) {
|
|
|
- transcriber.close();
|
|
|
- }
|
|
|
- NlsClient client = nlsClients.remove(sessionId);
|
|
|
- if (client != null) {
|
|
|
- client.shutdown();
|
|
|
- }
|
|
|
- transcriptionResults.remove(sessionId);
|
|
|
- latches.remove(sessionId);
|
|
|
- } catch (Exception e) {
|
|
|
- logger.error("清理会话资源失败", e);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * 接收音频数据并发送到阿里云
|
|
|
- */
|
|
|
- @PostMapping("/stream")
|
|
|
- public Map<String, Object> streamAudio(HttpServletRequest request) {
|
|
|
- try {
|
|
|
- // 从请求参数中获取sessionId
|
|
|
- String sessionId = null;
|
|
|
-
|
|
|
- // 首先尝试从URL参数中获取
|
|
|
- sessionId = request.getParameter("sessionId");
|
|
|
-
|
|
|
- // 如果URL参数中没有,尝试从multipart表单中获取
|
|
|
- if (sessionId == null || sessionId.isEmpty()) {
|
|
|
- if (request instanceof MultipartHttpServletRequest) {
|
|
|
- MultipartHttpServletRequest multipartRequest = (MultipartHttpServletRequest) request;
|
|
|
- sessionId = multipartRequest.getParameter("sessionId");
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- if (sessionId == null || sessionId.isEmpty()) {
|
|
|
- logger.error("sessionId参数缺失,请求参数: {}", request.getParameterMap());
|
|
|
- return Map.of("success", false, "message", "请求参数缺失:sessionId");
|
|
|
- }
|
|
|
-
|
|
|
- // 获取音频文件
|
|
|
- MultipartFile audioFile = null;
|
|
|
- if (request instanceof MultipartHttpServletRequest) {
|
|
|
- MultipartHttpServletRequest multipartRequest = (MultipartHttpServletRequest) request;
|
|
|
- audioFile = multipartRequest.getFile("audio");
|
|
|
- }
|
|
|
-
|
|
|
- SpeechTranscriber transcriber = transcribers.get(sessionId);
|
|
|
- if (transcriber == null) {
|
|
|
- logger.error("会话不存在,sessionId: {}", sessionId);
|
|
|
- return Map.of("success", false, "message", "语音识别会话未开始");
|
|
|
- }
|
|
|
-
|
|
|
- if (audioFile == null || audioFile.isEmpty()) {
|
|
|
- // 忽略空数据,避免发送空数据到阿里云
|
|
|
- logger.info("接收到空音频数据,忽略处理");
|
|
|
- return Map.of("success", true, "message", "音频数据已接收");
|
|
|
- }
|
|
|
-
|
|
|
- byte[] audioData = audioFile.getBytes();
|
|
|
- if (audioData.length > 0) {
|
|
|
- transcriber.send(audioData, audioData.length);
|
|
|
- } else {
|
|
|
- // 忽略空数据,避免发送空数据到阿里云
|
|
|
- logger.info("接收到空音频数据,忽略处理");
|
|
|
- }
|
|
|
-
|
|
|
- return Map.of("success", true, "message", "音频数据已接收");
|
|
|
- } catch (Exception e) {
|
|
|
- logger.error("处理音频数据失败", e);
|
|
|
- return Map.of("success", false, "message", "处理音频数据失败: " + e.getMessage());
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * 结束语音识别会话并返回结果
|
|
|
- */
|
|
|
- @PostMapping("/stop")
|
|
|
- public Map<String, Object> stopRecognition(@RequestParam("sessionId") String sessionId) {
|
|
|
- try {
|
|
|
- SpeechTranscriber transcriber = transcribers.get(sessionId);
|
|
|
- CountDownLatch latch = latches.get(sessionId);
|
|
|
-
|
|
|
- if (transcriber != null) {
|
|
|
- try {
|
|
|
- transcriber.stop();
|
|
|
- // 等待识别完成
|
|
|
- if (latch != null) {
|
|
|
- latch.await(5, TimeUnit.SECONDS);
|
|
|
- }
|
|
|
- } catch (Exception e) {
|
|
|
- logger.error("停止transcriber失败", e);
|
|
|
- // 继续执行,确保资源被清理
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // 获取识别结果
|
|
|
- StringBuilder result = transcriptionResults.get(sessionId);
|
|
|
- String finalResult = result != null ? result.toString() : "";
|
|
|
-
|
|
|
- // 清理资源
|
|
|
- cleanupSession(sessionId);
|
|
|
-
|
|
|
- return Map.of("success", true, "result", finalResult);
|
|
|
- } catch (Exception e) {
|
|
|
- logger.error("停止语音识别失败", e);
|
|
|
- // 确保资源被清理
|
|
|
- cleanupSession(sessionId);
|
|
|
- return Map.of("success", false, "message", "停止语音识别失败: " + e.getMessage());
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * 获取语音识别监听器
|
|
|
- */
|
|
|
- private SpeechTranscriberListener getTranscriberListener(String sessionId) {
|
|
|
- return new SpeechTranscriberListener() {
|
|
|
- @Override
|
|
|
- public void onTranscriptionResultChange(SpeechTranscriberResponse response) {
|
|
|
- String result = response.getTransSentenceText();
|
|
|
- logger.info("实时识别中间结果: " + result);
|
|
|
-
|
|
|
- // 通过WebSocket实时返回中间结果
|
|
|
- WebSocketSession session = webSocketSessions.get(sessionId);
|
|
|
- if (session != null && session.isOpen()) {
|
|
|
- try {
|
|
|
- session.sendMessage(new TextMessage("{\"type\":\"intermediate\",\"result\":\"" + result + "\"}"));
|
|
|
- } catch (Exception e) {
|
|
|
- logger.error("发送WebSocket消息失败", e);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void onTranscriberStart(SpeechTranscriberResponse response) {
|
|
|
- logger.info("语音识别会话开始, task_id: " + response.getTaskId());
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void onSentenceBegin(SpeechTranscriberResponse response) {
|
|
|
- logger.info("开始识别新句子");
|
|
|
-
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void onSentenceEnd(SpeechTranscriberResponse response) {
|
|
|
- logger.info("句子识别完成, 结果: " + response.getTransSentenceText());
|
|
|
- logger.info("置信度: " + response.getConfidence() + ", 开始时间: " + response.getSentenceBeginTime() + ", 处理时长: " + response.getTransSentenceTime() + "ms");
|
|
|
- StringBuilder result = transcriptionResults.get(sessionId);
|
|
|
- if (result != null) {
|
|
|
- result.append(response.getTransSentenceText());
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void onTranscriptionComplete(SpeechTranscriberResponse response) {
|
|
|
- logger.info("语音识别会话完成");
|
|
|
- CountDownLatch latch = latches.get(sessionId);
|
|
|
- if (latch != null) {
|
|
|
- latch.countDown();
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void onFail(SpeechTranscriberResponse response) {
|
|
|
- logger.error("语音识别失败: " + response.getStatusText() + ", 状态码: " + response.getStatus());
|
|
|
- CountDownLatch latch = latches.get(sessionId);
|
|
|
- if (latch != null) {
|
|
|
- latch.countDown();
|
|
|
- }
|
|
|
- }
|
|
|
- };
|
|
|
- }
|
|
|
-}
|