|
|
@@ -0,0 +1,337 @@
|
|
|
+package cn.iocoder.byzs.module.ai.controller.admin.speech;
|
|
|
+
|
|
|
+import com.alibaba.nls.client.AccessToken;
|
|
|
+import com.alibaba.nls.client.protocol.NlsClient;
|
|
|
+import com.alibaba.nls.client.protocol.InputFormatEnum;
|
|
|
+import com.alibaba.nls.client.protocol.SampleRateEnum;
|
|
|
+import com.alibaba.nls.client.protocol.asr.SpeechTranscriber;
|
|
|
+import com.alibaba.nls.client.protocol.asr.SpeechTranscriberListener;
|
|
|
+import com.alibaba.nls.client.protocol.asr.SpeechTranscriberResponse;
|
|
|
+import org.slf4j.Logger;
|
|
|
+import org.slf4j.LoggerFactory;
|
|
|
+import org.springframework.beans.factory.annotation.Value;
|
|
|
+import org.springframework.web.bind.annotation.*;
|
|
|
+import org.springframework.web.multipart.MultipartFile;
|
|
|
+import org.springframework.web.socket.CloseStatus;
|
|
|
+import org.springframework.web.socket.TextMessage;
|
|
|
+import org.springframework.web.socket.WebSocketSession;
|
|
|
+import org.springframework.web.socket.handler.TextWebSocketHandler;
|
|
|
+
|
|
|
+import jakarta.servlet.http.HttpServletRequest;
|
|
|
+import org.springframework.web.context.request.RequestContextHolder;
|
|
|
+import org.springframework.web.context.request.ServletRequestAttributes;
|
|
|
+import org.springframework.web.multipart.MultipartFile;
|
|
|
+import org.springframework.web.multipart.MultipartHttpServletRequest;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.concurrent.ConcurrentHashMap;
|
|
|
+import java.util.concurrent.CountDownLatch;
|
|
|
+import java.util.concurrent.TimeUnit;
|
|
|
+
|
|
|
+@RestController
|
|
|
+@RequestMapping("/admin/ai/speech")
|
|
|
+public class RealTimeSpeechController {
|
|
|
+
|
|
|
+ private static final Logger logger = LoggerFactory.getLogger(RealTimeSpeechController.class);
|
|
|
+
|
|
|
+ @Value("${ai.aliyun.app-key:4SUOF4LfaU7FekyW}")
|
|
|
+ private String appKey;
|
|
|
+
|
|
|
+ @Value("${ai.aliyun.access-key-id:LTAI5tQhMPLXtSgXiPiWbw6D}")
|
|
|
+ private String accessKeyId;
|
|
|
+
|
|
|
+ @Value("${ai.aliyun.access-key-secret:HCXpFYjl4swk0qwfIKa9s2bXx0AWcG}")
|
|
|
+ private String accessKeySecret;
|
|
|
+
|
|
|
+ @Value("${ai.aliyun.nls-gateway-url:wss://nls-gateway-cn-shanghai.aliyuncs.com/ws/v1}")
|
|
|
+ private String nlsGatewayUrl;
|
|
|
+
|
|
|
+ // 存储每个会话的识别结果
|
|
|
+ private final Map<String, StringBuilder> transcriptionResults = new ConcurrentHashMap<>();
|
|
|
+
|
|
|
+ // 存储每个会话的NlsClient
|
|
|
+ private final Map<String, NlsClient> nlsClients = new ConcurrentHashMap<>();
|
|
|
+
|
|
|
+ // 存储每个会话的SpeechTranscriber
|
|
|
+ private final Map<String, SpeechTranscriber> transcribers = new ConcurrentHashMap<>();
|
|
|
+
|
|
|
+ // 存储每个会话的结束信号
|
|
|
+ private final Map<String, CountDownLatch> latches = new ConcurrentHashMap<>();
|
|
|
+
|
|
|
+ // 存储WebSocket会话,用于实时返回中间结果
|
|
|
+ private final Map<String, WebSocketSession> webSocketSessions = new ConcurrentHashMap<>();
|
|
|
+
|
|
|
+ // WebSocket处理器
|
|
|
+ public class SpeechWebSocketHandler extends TextWebSocketHandler {
|
|
|
+ private final Map<String, WebSocketSession> webSocketSessions;
|
|
|
+ private final Logger logger = LoggerFactory.getLogger(SpeechWebSocketHandler.class);
|
|
|
+
|
|
|
+ public SpeechWebSocketHandler(Map<String, WebSocketSession> webSocketSessions) {
|
|
|
+ this.webSocketSessions = webSocketSessions;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void afterConnectionEstablished(WebSocketSession session) throws Exception {
|
|
|
+ // 从会话参数中获取sessionId
|
|
|
+ String sessionId = session.getUri().getQuery().split("=")[1];
|
|
|
+ webSocketSessions.put(sessionId, session);
|
|
|
+ logger.info("WebSocket连接建立,sessionId: {}", sessionId);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
|
|
|
+ // 处理客户端消息
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
|
|
|
+ // 从会话参数中获取sessionId
|
|
|
+ String sessionId = session.getUri().getQuery().split("=")[1];
|
|
|
+ webSocketSessions.remove(sessionId);
|
|
|
+ logger.info("WebSocket连接关闭,sessionId: {}", sessionId);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 开始语音识别会话
|
|
|
+ */
|
|
|
+ @PostMapping("/start")
|
|
|
+ public Map<String, Object> startRecognition() {
|
|
|
+ // 生成唯一的会话ID
|
|
|
+ String sessionId = java.util.UUID.randomUUID().toString();
|
|
|
+ try {
|
|
|
+ // 清理旧的会话资源
|
|
|
+ cleanupSession(sessionId);
|
|
|
+
|
|
|
+ // 初始化NlsClient
|
|
|
+ AccessToken accessToken = new AccessToken(accessKeyId, accessKeySecret);
|
|
|
+ accessToken.apply();
|
|
|
+ NlsClient client = new NlsClient(nlsGatewayUrl, accessToken.getToken());
|
|
|
+ nlsClients.put(sessionId, client);
|
|
|
+
|
|
|
+ // 初始化识别结果
|
|
|
+ transcriptionResults.put(sessionId, new StringBuilder());
|
|
|
+
|
|
|
+ // 初始化结束信号
|
|
|
+ latches.put(sessionId, new CountDownLatch(1));
|
|
|
+
|
|
|
+ // 创建SpeechTranscriber
|
|
|
+ SpeechTranscriber transcriber = new SpeechTranscriber(client, getTranscriberListener(sessionId));
|
|
|
+ transcriber.setAppKey(appKey);
|
|
|
+ transcriber.setFormat(InputFormatEnum.PCM);
|
|
|
+ transcriber.setSampleRate(SampleRateEnum.SAMPLE_RATE_16K);
|
|
|
+ transcriber.setEnableIntermediateResult(true);
|
|
|
+ transcriber.setEnablePunctuation(true);
|
|
|
+ transcriber.setEnableITN(false);
|
|
|
+ transcriber.start();
|
|
|
+
|
|
|
+ transcribers.put(sessionId, transcriber);
|
|
|
+
|
|
|
+ // 启动心跳线程,避免WebSocket会话超时
|
|
|
+ startHeartbeatThread(sessionId, transcriber);
|
|
|
+
|
|
|
+ return Map.of("success", true, "message", "语音识别会话已开始", "sessionId", sessionId);
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.error("开始语音识别失败", e);
|
|
|
+ cleanupSession(sessionId);
|
|
|
+ return Map.of("success", false, "message", "开始语音识别失败: " + e.getMessage());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 启动心跳线程,定期发送空数据保持WebSocket连接
|
|
|
+ */
|
|
|
+ private void startHeartbeatThread(String sessionId, SpeechTranscriber transcriber) {
|
|
|
+ Thread heartbeatThread = new Thread(() -> {
|
|
|
+ try {
|
|
|
+ while (transcribers.containsKey(sessionId)) {
|
|
|
+ // 发送空数据保持连接
|
|
|
+ transcriber.send(new byte[0], 0);
|
|
|
+ Thread.sleep(5000); // 每5秒发送一次心跳
|
|
|
+ }
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.error("心跳线程异常", e);
|
|
|
+ }
|
|
|
+ });
|
|
|
+ heartbeatThread.setDaemon(true);
|
|
|
+ heartbeatThread.start();
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 清理会话资源
|
|
|
+ */
|
|
|
+ private void cleanupSession(String sessionId) {
|
|
|
+ try {
|
|
|
+ SpeechTranscriber transcriber = transcribers.remove(sessionId);
|
|
|
+ if (transcriber != null) {
|
|
|
+ transcriber.close();
|
|
|
+ }
|
|
|
+ NlsClient client = nlsClients.remove(sessionId);
|
|
|
+ if (client != null) {
|
|
|
+ client.shutdown();
|
|
|
+ }
|
|
|
+ transcriptionResults.remove(sessionId);
|
|
|
+ latches.remove(sessionId);
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.error("清理会话资源失败", e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 接收音频数据并发送到阿里云
|
|
|
+ */
|
|
|
+ @PostMapping("/stream")
|
|
|
+ public Map<String, Object> streamAudio(HttpServletRequest request) {
|
|
|
+ try {
|
|
|
+ // 从请求参数中获取sessionId
|
|
|
+ String sessionId = null;
|
|
|
+
|
|
|
+ // 首先尝试从URL参数中获取
|
|
|
+ sessionId = request.getParameter("sessionId");
|
|
|
+
|
|
|
+ // 如果URL参数中没有,尝试从multipart表单中获取
|
|
|
+ if (sessionId == null || sessionId.isEmpty()) {
|
|
|
+ if (request instanceof MultipartHttpServletRequest) {
|
|
|
+ MultipartHttpServletRequest multipartRequest = (MultipartHttpServletRequest) request;
|
|
|
+ sessionId = multipartRequest.getParameter("sessionId");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if (sessionId == null || sessionId.isEmpty()) {
|
|
|
+ logger.error("sessionId参数缺失,请求参数: {}", request.getParameterMap());
|
|
|
+ return Map.of("success", false, "message", "请求参数缺失:sessionId");
|
|
|
+ }
|
|
|
+
|
|
|
+ // 获取音频文件
|
|
|
+ MultipartFile audioFile = null;
|
|
|
+ if (request instanceof MultipartHttpServletRequest) {
|
|
|
+ MultipartHttpServletRequest multipartRequest = (MultipartHttpServletRequest) request;
|
|
|
+ audioFile = multipartRequest.getFile("audio");
|
|
|
+ }
|
|
|
+
|
|
|
+ SpeechTranscriber transcriber = transcribers.get(sessionId);
|
|
|
+ if (transcriber == null) {
|
|
|
+ logger.error("会话不存在,sessionId: {}", sessionId);
|
|
|
+ return Map.of("success", false, "message", "语音识别会话未开始");
|
|
|
+ }
|
|
|
+
|
|
|
+ if (audioFile == null || audioFile.isEmpty()) {
|
|
|
+ // 忽略空数据,避免发送空数据到阿里云
|
|
|
+ logger.info("接收到空音频数据,忽略处理");
|
|
|
+ return Map.of("success", true, "message", "音频数据已接收");
|
|
|
+ }
|
|
|
+
|
|
|
+ byte[] audioData = audioFile.getBytes();
|
|
|
+ if (audioData.length > 0) {
|
|
|
+ transcriber.send(audioData, audioData.length);
|
|
|
+ } else {
|
|
|
+ // 忽略空数据,避免发送空数据到阿里云
|
|
|
+ logger.info("接收到空音频数据,忽略处理");
|
|
|
+ }
|
|
|
+
|
|
|
+ return Map.of("success", true, "message", "音频数据已接收");
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.error("处理音频数据失败", e);
|
|
|
+ return Map.of("success", false, "message", "处理音频数据失败: " + e.getMessage());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 结束语音识别会话并返回结果
|
|
|
+ */
|
|
|
+ @PostMapping("/stop")
|
|
|
+ public Map<String, Object> stopRecognition(@RequestParam("sessionId") String sessionId) {
|
|
|
+ try {
|
|
|
+ SpeechTranscriber transcriber = transcribers.get(sessionId);
|
|
|
+ CountDownLatch latch = latches.get(sessionId);
|
|
|
+
|
|
|
+ if (transcriber != null) {
|
|
|
+ try {
|
|
|
+ transcriber.stop();
|
|
|
+ // 等待识别完成
|
|
|
+ if (latch != null) {
|
|
|
+ latch.await(5, TimeUnit.SECONDS);
|
|
|
+ }
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.error("停止transcriber失败", e);
|
|
|
+ // 继续执行,确保资源被清理
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 获取识别结果
|
|
|
+ StringBuilder result = transcriptionResults.get(sessionId);
|
|
|
+ String finalResult = result != null ? result.toString() : "";
|
|
|
+
|
|
|
+ // 清理资源
|
|
|
+ cleanupSession(sessionId);
|
|
|
+
|
|
|
+ return Map.of("success", true, "result", finalResult);
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.error("停止语音识别失败", e);
|
|
|
+ // 确保资源被清理
|
|
|
+ cleanupSession(sessionId);
|
|
|
+ return Map.of("success", false, "message", "停止语音识别失败: " + e.getMessage());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 获取语音识别监听器
|
|
|
+ */
|
|
|
+ private SpeechTranscriberListener getTranscriberListener(String sessionId) {
|
|
|
+ return new SpeechTranscriberListener() {
|
|
|
+ @Override
|
|
|
+ public void onTranscriptionResultChange(SpeechTranscriberResponse response) {
|
|
|
+ String result = response.getTransSentenceText();
|
|
|
+ logger.info("实时识别中间结果: " + result);
|
|
|
+
|
|
|
+ // 通过WebSocket实时返回中间结果
|
|
|
+ WebSocketSession session = webSocketSessions.get(sessionId);
|
|
|
+ if (session != null && session.isOpen()) {
|
|
|
+ try {
|
|
|
+ session.sendMessage(new TextMessage("{\"type\":\"intermediate\",\"result\":\"" + result + "\"}"));
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.error("发送WebSocket消息失败", e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onTranscriberStart(SpeechTranscriberResponse response) {
|
|
|
+ logger.info("语音识别会话开始, task_id: " + response.getTaskId());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onSentenceBegin(SpeechTranscriberResponse response) {
|
|
|
+ logger.info("开始识别新句子");
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onSentenceEnd(SpeechTranscriberResponse response) {
|
|
|
+ logger.info("句子识别完成, 结果: " + response.getTransSentenceText());
|
|
|
+ logger.info("置信度: " + response.getConfidence() + ", 开始时间: " + response.getSentenceBeginTime() + ", 处理时长: " + response.getTransSentenceTime() + "ms");
|
|
|
+ StringBuilder result = transcriptionResults.get(sessionId);
|
|
|
+ if (result != null) {
|
|
|
+ result.append(response.getTransSentenceText());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onTranscriptionComplete(SpeechTranscriberResponse response) {
|
|
|
+ logger.info("语音识别会话完成");
|
|
|
+ CountDownLatch latch = latches.get(sessionId);
|
|
|
+ if (latch != null) {
|
|
|
+ latch.countDown();
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onFail(SpeechTranscriberResponse response) {
|
|
|
+ logger.error("语音识别失败: " + response.getStatusText() + ", 状态码: " + response.getStatus());
|
|
|
+ CountDownLatch latch = latches.get(sessionId);
|
|
|
+ if (latch != null) {
|
|
|
+ latch.countDown();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ };
|
|
|
+ }
|
|
|
+}
|