Bläddra i källkod

新增tts文转音方法,默认pcm格式

liyanbo 2 månader sedan
förälder
incheckning
b2457273eb

+ 12 - 0
byzs-module-ai/src/main/java/cn/iocoder/byzs/module/ai/controller/admin/tts/AiTtsController.java

@@ -19,6 +19,7 @@ import io.swagger.v3.oas.annotations.tags.Tag;
 import jakarta.annotation.Resource;
 import jakarta.servlet.http.HttpServletResponse;
 import jakarta.validation.Valid;
+import lombok.extern.slf4j.Slf4j;
 import org.springframework.security.access.prepost.PreAuthorize;
 import org.springframework.validation.annotation.Validated;
 import org.springframework.web.bind.annotation.*;
@@ -34,6 +35,7 @@ import static cn.iocoder.byzs.framework.common.util.collection.CollectionUtils.c
 @RestController
 @RequestMapping("/ai/tts")
 @Validated
+@Slf4j
 public class AiTtsController {
 
     @Resource
@@ -111,4 +113,14 @@ public class AiTtsController {
                 .setId(tool.getId()).setName(tool.getName())));
     }
 
+    @PostMapping("/convert")
+    @Operation(summary = "文本转语音")
+    @PreAuthorize("@ss.hasPermission('ai:tts:convert')")
+    public CommonResult<String> convertTextToSpeech(@RequestParam("roleId") Long roleId, @RequestParam("content") String content) {
+        // 执行文本转语音
+        String audioUrl = ttsService.convertTextToSpeech(roleId, content);
+        // 返回结果
+        return success(audioUrl);
+    }
+
 }

+ 9 - 0
byzs-module-ai/src/main/java/cn/iocoder/byzs/module/ai/service/tts/AiTtsService.java

@@ -68,4 +68,13 @@ public interface AiTtsService {
      */
     List<AiTtsDO> getTtsSimpleListByStatus(Integer status);
 
+    /**
+     * 文本转语音
+     *
+     * @param roleId 角色编号
+     * @param content 需要转语音的内容
+     * @return 语音文件URL
+     */
+    String convertTextToSpeech(Long roleId, String content);
+
 }

+ 67 - 0
byzs-module-ai/src/main/java/cn/iocoder/byzs/module/ai/service/tts/AiTtsServiceImpl.java

@@ -5,8 +5,13 @@ import cn.iocoder.byzs.framework.common.pojo.PageResult;
 import cn.iocoder.byzs.framework.common.util.object.BeanUtils;
 import cn.iocoder.byzs.module.ai.controller.admin.tts.vo.AiTtsPageReqVO;
 import cn.iocoder.byzs.module.ai.controller.admin.tts.vo.AiTtsSaveReqVO;
+import cn.iocoder.byzs.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.byzs.module.ai.dal.dataobject.tts.AiTtsDO;
+import cn.iocoder.byzs.module.ai.dal.mysql.model.AiChatRoleMapper;
 import cn.iocoder.byzs.module.ai.dal.mysql.tts.AiTtsMapper;
+import cn.iocoder.byzs.module.ai.util.tts.StreamTtsService;
+import cn.iocoder.byzs.module.infra.api.file.FileApi;
+import com.alibaba.nls.client.protocol.OutputFormatEnum;
 import jakarta.annotation.Resource;
 import org.springframework.stereotype.Service;
 import org.springframework.validation.annotation.Validated;
@@ -28,6 +33,15 @@ public class AiTtsServiceImpl implements AiTtsService {
     @Resource
     private AiTtsMapper ttsMapper;
 
+    @Resource
+    private AiChatRoleMapper chatRoleMapper;
+
+    @Resource
+    private org.springframework.beans.factory.ObjectProvider<StreamTtsService> streamTtsServiceProvider;
+
+    @Resource
+    private FileApi fileApi;
+
     @Override
     public Long createTts(AiTtsSaveReqVO createReqVO) {
         // 插入
@@ -90,4 +104,57 @@ public class AiTtsServiceImpl implements AiTtsService {
         return ttsMapper.getTtsSimpleListByStatus(status);
     }
 
+    @Override
+    public String convertTextToSpeech(Long roleId, String content) {
+        // 1. 根据角色id查询角色信息
+        AiChatRoleDO chatRole = chatRoleMapper.selectById(roleId);
+        if (chatRole == null) {
+            throw exception(TTS_NOT_EXISTS);
+        }
+
+        // 2. 根据角色的ttsId查询TTS配置
+        Long ttsId = chatRole.getTtsId();
+        if (ttsId == null) {
+            throw exception(TTS_NOT_EXISTS);
+        }
+        AiTtsDO aiTtsDO = ttsMapper.selectById(ttsId);
+        if (aiTtsDO == null) {
+            throw exception(TTS_NOT_EXISTS);
+        }
+
+        // 3. 使用StreamTtsService将文本转语音
+        StreamTtsService streamTtsService = streamTtsServiceProvider.getObject();
+        try {
+            // 创建音频数据缓冲区
+            java.io.ByteArrayOutputStream audioOutputStream = new java.io.ByteArrayOutputStream();
+            // 设置音频数据回调
+            streamTtsService.setAudioDataCallback(audioData -> {
+                try {
+                    audioOutputStream.write(audioData);
+                } catch (java.io.IOException e) {
+                    throw new RuntimeException("写入音频数据失败", e);
+                }
+            });
+
+            // 开始TTS语音合成
+            streamTtsService.startTts(aiTtsDO, OutputFormatEnum.MP3);
+            // 发送文本
+            streamTtsService.sendText(content);
+            // 停止TTS
+            streamTtsService.stopTts();
+
+            // 4. 存储语音文件并上传到服务器
+            byte[] mp3Data = audioOutputStream.toByteArray();
+            String filePath = fileApi.createFile(mp3Data);
+            return filePath;
+        } catch (Exception e) {
+            throw new RuntimeException("文本转语音失败", e);
+        } finally {
+            // 确保资源被释放
+            if (streamTtsService != null) {
+                streamTtsService.stopTts();
+            }
+        }
+    }
+
 }

+ 10 - 5
byzs-module-ai/src/main/java/cn/iocoder/byzs/module/ai/util/tts/StreamTtsService.java

@@ -10,9 +10,7 @@ import com.alibaba.nls.client.protocol.tts.StreamInputTtsListener;
 import com.alibaba.nls.client.protocol.tts.StreamInputTtsResponse;
 import jakarta.annotation.PostConstruct;
 import jakarta.annotation.PreDestroy;
-import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
-import org.springframework.beans.factory.ObjectProvider;
 import org.springframework.beans.factory.annotation.Value;
 import org.springframework.context.annotation.Scope;
 import org.springframework.stereotype.Service;
@@ -73,6 +71,13 @@ public class StreamTtsService {
      * 开始TTS语音合成
      */
     public void startTts(AiTtsDO aiTtsDO) {
+        startTts(aiTtsDO, OutputFormatEnum.PCM);
+    }
+
+    /**
+     * 开始TTS语音合成
+     */
+    public void startTts(AiTtsDO aiTtsDO, OutputFormatEnum format) {
         // 创建TTS实例
         try {
             synthesizer = new StreamInputTts(client, getSynthesizerListener());
@@ -81,8 +86,8 @@ public class StreamTtsService {
             throw new RuntimeException("创建TTS实例", e);
         }
         synthesizer.setAppKey(appKey);
-        synthesizer.setFormat(OutputFormatEnum.PCM); // 确保输出PCM格式
-        synthesizer.setSampleRate(SampleRateEnum.SAMPLE_RATE_16K); // 24000Hz采样率
+        synthesizer.setFormat(format); // 设置输出格式
+        synthesizer.setSampleRate(SampleRateEnum.SAMPLE_RATE_16K); // 16000Hz采样率
         synthesizer.setVoice(aiTtsDO.getModel());
         synthesizer.setVolume(aiTtsDO.getVolume());
         synthesizer.setPitchRate(aiTtsDO.getPitchRate());
@@ -219,4 +224,4 @@ public class StreamTtsService {
         this.onCompleteCallback = callback;
     }
 
-}
+}