前往讯飞开放平台选择产品,获取 appId、APIKey、APISecret。这里我选择的是 v3.0 模型。
java后端实现
本项目已经实现了基本的会话功能,小伙伴可以自己扩充其他功能,例如绘画功能。
注意:星火模型的 API 使用的是 WebSocket 协议,和 ChatGPT 的 HTTP 不一样;而且它的鉴权也不单单是 APIKey,需要用特殊的算法(HMAC-SHA256 签名)生成鉴权 URL。详情可查看官网开发文档。
下面就来看看实现:
pom.xml,主要依赖
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.xyb</groupId>
    <artifactId>xfxh-sdk-java</artifactId>
    <version>1.0-SNAPSHOT</version>
    <name>xfxh-sdk-java</name>
    <description>Demo project for Spring Boot</description>
    <properties>
        <java.version>1.8</java.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.24</version>
            <scope>compile</scope>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>2.0.6</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-simple</artifactId>
            <version>2.0.6</version>
        </dependency>
        <dependency>
            <groupId>cn.hutool</groupId>
            <artifactId>hutool-all</artifactId>
            <version>5.8.18</version>
        </dependency>
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <!-- 1.2.83 fixes the autoType deserialization RCE (CVE-2022-25845) that affects 1.2.67 -->
            <version>1.2.83</version>
        </dependency>
        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>okhttp-sse</artifactId>
            <version>3.14.9</version>
        </dependency>
        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>logging-interceptor</artifactId>
            <version>3.14.9</version>
        </dependency>
        <dependency>
            <groupId>com.squareup.retrofit2</groupId>
            <artifactId>retrofit</artifactId>
            <version>2.9.0</version>
        </dependency>
        <dependency>
            <groupId>com.squareup.retrofit2</groupId>
            <artifactId>converter-jackson</artifactId>
            <version>2.9.0</version>
        </dependency>
        <dependency>
            <groupId>com.squareup.retrofit2</groupId>
            <artifactId>adapter-rxjava2</artifactId>
            <version>2.9.0</version>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.13.2</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.jetbrains</groupId>
            <artifactId>annotations</artifactId>
            <!-- Pinned: the RELEASE meta-version resolves to whatever is newest and makes builds non-reproducible -->
            <version>24.0.1</version>
            <scope>compile</scope>
        </dependency>
    </dependencies>
    <build>
        <!-- NOTE(review): finalName differs from the artifactId (xfxh-sdk-java); confirm the intended jar name -->
        <finalName>chatgpt-sdk-java</finalName>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-surefire-plugin</artifactId>
                <version>2.12.4</version>
                <configuration>
                    <!-- Tests are skipped by default; the ApiTest in this project needs real credentials and network access -->
                    <skipTests>true</skipTests>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <!-- Explicit version avoids Maven's "plugin version is missing" warning and keeps builds stable -->
                <version>3.8.1</version>
                <configuration>
                    <source>8</source>
                    <target>8</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
请求参数
package com.xyb.xfxh.dto;
import com.alibaba.fastjson.annotation.JSONField;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
/**
* 请求参数
*/
@NoArgsConstructor
@Data
@Builder
@AllArgsConstructor
public class RequestDTO {
@JsonProperty("header")
private HeaderDTO header;
@JsonProperty("parameter")
private ParameterDTO parameter;
@JsonProperty("payload")
private PayloadDTO payload;
@NoArgsConstructor
@Data
@AllArgsConstructor
@Builder
public static class HeaderDTO {
/**
* 应用appid,从开放平台控制台创建的应用中获取
*/
@JSONField(name = "app_id")
private String appId;
/**
* 每个用户的id,用于区分不同用户,最大长度32
*/
@JSONField(name = "uid")
private String uid;
}
@NoArgsConstructor
@Data
@AllArgsConstructor
@Builder
public static class ParameterDTO {
private ChatDTO chat;
@NoArgsConstructor
@Data
@AllArgsConstructor
@Builder
public static class ChatDTO {
/**
* 指定访问的领域,generalv3指向V3.0版本!
*/
@JsonProperty("domain")
private String domain = "generalv3";
/**
* 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高
*/
@JsonProperty("temperature")
private Float temperature = 0.5F;
/**
* 模型回答的tokens的最大长度
*/
@JSONField(name = "max_tokens")
private Integer maxTokens = 2048;
}
}
@NoArgsConstructor
@Data
@AllArgsConstructor
@Builder
public static class PayloadDTO {
@JsonProperty("message")
private MessageDTO message;
@NoArgsConstructor
@Data
@AllArgsConstructor
@Builder
public static class MessageDTO {
@JsonProperty("text")
private List<MsgDTO> text;
}
}
}
注意:domain具体看你选择的模型
响应参数
package com.xyb.xfxh.dto;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
/**
 * One response frame of the Spark chat API; the answer is streamed as several such frames.
 * <p>
 * NOTE(review): ApiTest parses this class with fastjson ({@code JSONObject.parseObject}), which
 * presumably does not read the Jackson {@code @JsonProperty} names below; snake_case fields such as
 * {@code prompt_tokens} then rely on fastjson's own field-name matching. Confirm before relying on them.
 */
@NoArgsConstructor
@Data
public class ResponseDTO {
@JsonProperty("header")
private HeaderDTO header;
@JsonProperty("payload")
private PayloadDTO payload;
/** Frame header: status and error information. */
@NoArgsConstructor
@Data
public static class HeaderDTO {
/**
* Error code: 0 means success, any non-zero value means an error.
*/
@JsonProperty("code")
private Integer code;
/**
* Description of whether the session succeeded.
*/
@JsonProperty("message")
private String message;
/**
* Unique session id, used by iFlytek engineers to look up server-side logs; keep it when a call fails.
*/
@JsonProperty("sid")
private String sid;
/**
* Session status, one of [0, 1, 2]: 0 = first result, 1 = intermediate result, 2 = last result.
*/
@JsonProperty("status")
private Integer status;
}
/** Frame payload: answer text and, in the last frame, token usage. */
@NoArgsConstructor
@Data
public static class PayloadDTO {
@JsonProperty("choices")
private ChoicesDTO choices;
/**
* Returned only in the last result.
*/
@JsonProperty("usage")
private UsageDTO usage;
@NoArgsConstructor
@Data
public static class ChoicesDTO {
/**
* Text response status, one of [0, 1, 2]: 0 = first text result, 1 = intermediate, 2 = last text result.
*/
@JsonProperty("status")
private Integer status;
/**
* Sequence number of the returned data, in [0, 9999999].
*/
@JsonProperty("seq")
private Integer seq;
/**
* Response text fragments.
*/
@JsonProperty("text")
private List<MsgDTO> text;
}
@NoArgsConstructor
@Data
public static class UsageDTO {
@JsonProperty("text")
private TextDTO text;
@NoArgsConstructor
@Data
public static class TextDTO {
/**
* Reserved field; can be ignored.
*/
@JsonProperty("question_tokens")
private Integer questionTokens;
/**
* Total token count including the conversation history.
*/
@JsonProperty("prompt_tokens")
private Integer promptTokens;
/**
* Token count of the answer.
*/
@JsonProperty("completion_tokens")
private Integer completionTokens;
/**
* Sum of prompt_tokens and completion_tokens; this is also the billed token count of the exchange.
*/
@JsonProperty("total_tokens")
private Integer totalTokens;
}
}
}
}
MsgDTO 类
package com.xyb.xfxh.dto;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * A single chat message. In requests it is one conversation turn (role + content); in responses
 * it is one text fragment. Null fields are left out of Jackson output.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
@Builder
@JsonInclude(JsonInclude.Include.NON_NULL)
public class MsgDTO {

    /** Role value for turns written by the user. */
    public static final String ROLE_USER = "user";

    /** Role value for turns produced by the model. */
    public static final String ROLE_ASSISTANT = "assistant";

    /** Speaker role, normally {@link #ROLE_USER} or {@link #ROLE_ASSISTANT}. */
    private String role;

    /** Message text. */
    private String content;

    /** Response-only field: result index in [0, 10]; currently reserved and safe to ignore. */
    private Integer index;

    /**
     * Creates a user turn.
     *
     * @param content message text
     * @return a message with role {@link #ROLE_USER} and no index
     */
    public static MsgDTO createUserMsg(String content) {
        return MsgDTO.builder()
                .role(ROLE_USER)
                .content(content)
                .build();
    }

    /**
     * Creates an assistant turn, e.g. to replay earlier model answers as history.
     *
     * @param content message text
     * @return a message with role {@link #ROLE_ASSISTANT} and no index
     */
    public static MsgDTO createAssistantMsg(String content) {
        return MsgDTO.builder()
                .role(ROLE_ASSISTANT)
                .content(content)
                .build();
    }
}
定义接口
/**
 * Retrofit (HTTP) declaration of the Spark chat endpoint.
 * <p>
 * NOTE(review): the article states that the Spark API is WebSocket-only, and
 * DefaultOpenAiSession.completions never calls this interface, so this POST mapping is presumably
 * unused. Confirm before relying on it.
 */
public interface IOpenAiApi {
// Path relative to the Retrofit base URL (configuration.getApiHost()); despite the "v1" in the name, it targets v3.1.
String v1_chat_completions = "v3.1/chat/";
/**
* Default Spark cognitive model question-answering call.
* @param chatCompletionRequest request payload
* @return the response, wrapped in an RxJava Single
*/
@POST(v1_chat_completions)
Single<ResponseDTO> completions(@Body RequestDTO chatCompletionRequest);
}
在IOpenAiApi 接口中定义访问接口,目前只有简单问答模型。
会话接口
public interface OpenAiSession {
/**
* 星火认知大模型
* @param requestDTO
* @param
* @return
*/
WebSocket completions(RequestDTO requestDTO, WebSocketListener listener) throws Exception;
/**
* 星火认知大模型, 用自己的数据
* @param requestDTO
* @param
* @return
*/
WebSocket completions(String apiHost, String apiKey, RequestDTO requestDTO, WebSocketListener listener) throws Exception;
会话接口 OpenAiSession 与 IOpenAiApi 看上去是有些类似的。但有了这样一个接口,就可以封装出各类需要的扩展方法了。
会话工厂
/**
 * Default {@link OpenAiSessionFactory}: builds the OkHttp client and the Retrofit API, stores them
 * in the {@link Configuration}, and opens a {@link DefaultOpenAiSession} over that configuration.
 */
public class DefaultOpenAiSessionFactory implements OpenAiSessionFactory {

    /** Connect / read / write timeout of the HTTP client, in seconds. */
    private static final long TIMEOUT_SECONDS = 450;

    private final Configuration configuration;

    public DefaultOpenAiSessionFactory(Configuration configuration) {
        this.configuration = configuration;
    }

    /**
     * Opens a new session. Each call builds a fresh OkHttpClient and Retrofit API and overwrites
     * the ones stored in the configuration.
     *
     * @return a session backed by this factory's configuration
     */
    @Override
    public OpenAiSession openSession() {
        // Log request/response lines and headers (no bodies).
        HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor();
        httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);

        // No auth interceptor: Spark authenticates through a signed request URL (see AuthUtil),
        // not through a request header.
        OkHttpClient okHttpClient = new OkHttpClient.Builder()
                .addInterceptor(httpLoggingInterceptor)
                .connectTimeout(TIMEOUT_SECONDS, TimeUnit.SECONDS)
                .writeTimeout(TIMEOUT_SECONDS, TimeUnit.SECONDS)
                .readTimeout(TIMEOUT_SECONDS, TimeUnit.SECONDS)
                .build();
        configuration.setOkHttpClient(okHttpClient);

        // Retrofit API for HTTP-based endpoints; the base URL must end with '/'.
        IOpenAiApi openAiApi = new Retrofit.Builder()
                .baseUrl(configuration.getApiHost())
                .client(okHttpClient)
                .addCallAdapterFactory(RxJava2CallAdapterFactory.create())
                .addConverterFactory(JacksonConverterFactory.create())
                .build()
                .create(IOpenAiApi.class);
        configuration.setOpenAiApi(openAiApi);

        return new DefaultOpenAiSession(configuration);
    }
}
本来的目的是把 HTTP 相关的配置统一在这里配好,结果星火 API 不支持 HTTP 协议;如果某个 API 是用 HTTP 请求的,可以在这里进行统一配置。
会话接口的实现
/**
 * Default {@link OpenAiSession}: talks to the Spark chat API over a WebSocket whose URL is
 * signed by {@link AuthUtil}.
 */
public class DefaultOpenAiSession implements OpenAiSession {

    /** Configuration (host, key, secret, ...). */
    private final Configuration configuration;

    /** SSE factory from the configuration; not used by the WebSocket chat path below. */
    private final EventSource.Factory factory;

    /** Retrofit API from the configuration; not used by the WebSocket chat path below. */
    private final IOpenAiApi openAiApi;

    /**
     * Client used to open chat WebSockets. An OkHttpClient owns a connection pool and a dispatcher
     * thread pool, so one instance is shared by all calls instead of building a new client per request.
     */
    private final OkHttpClient webSocketClient = new OkHttpClient();

    public DefaultOpenAiSession(Configuration configuration) {
        this.configuration = configuration;
        this.openAiApi = configuration.getOpenAiApi();
        this.factory = configuration.createRequestFactory();
    }

    @Override
    public WebSocket completions(RequestDTO chatCompletionRequest, WebSocketListener listener) throws Exception {
        return this.completions(null, null, chatCompletionRequest, listener);
    }

    /**
     * Opens the chat WebSocket and sends {@code chatCompletionRequest} as a JSON text frame.
     *
     * @param apiHostByUser         host override; see the NOTE below, currently not applied
     * @param apiKeyByUser          API key override; {@code null} uses the configured key
     * @param chatCompletionRequest request payload, serialized with fastjson
     * @param listener              receives the streamed response frames
     * @return the WebSocket; the caller closes or cancels it
     * @throws Exception if the signed URL cannot be built
     */
    @Override
    public WebSocket completions(String apiHostByUser, String apiKeyByUser, RequestDTO chatCompletionRequest, WebSocketListener listener) throws Exception {
        // NOTE(review): apiHostByUser is not used. AuthUtil.getKey(String, Configuration) signs against
        // configuration.getApiHost(), so a caller-supplied host is currently ignored.
        String apiKey = apiKeyByUser == null ? configuration.getApiKey() : apiKeyByUser;

        // Signed URL: <host>v3.1/chat?authorization=...&date=...&host=...
        // The URL carries credentials, so it is deliberately not printed or logged.
        String signedUrl = AuthUtil.getKey(apiKey, configuration);
        Request request = new Request.Builder()
                .url(signedUrl)
                .build();

        WebSocket webSocket = webSocketClient.newWebSocket(request, listener);
        // send() enqueues the frame; OkHttp transmits it once the handshake has completed.
        webSocket.send(JSONObject.toJSONString(chatCompletionRequest));
        return webSocket;
    }
}
这样就可以通过OpenAiSession接口调用此服务啦
websocket鉴权
开发者需要自行先在控制台创建应用,利用应用中提供的appid,APIKey, APISecret进行鉴权,生成最终请求的鉴权url
/**
 * Builds the signed request URL required by the Spark WebSocket API.
 * <p>
 * The signature is HMAC-SHA256, keyed with the APISecret, over
 * {@code "host: <host>\ndate: <date>\nGET <path> HTTP/1.1"}. The authorization string embeds the
 * APIKey and that signature, is Base64-encoded, and is sent together with {@code date} and
 * {@code host} as URL-encoded query parameters.
 */
public class AuthUtil {

    /** Chat endpoint path, relative to the API host; both requested and signed. */
    private static final String CHAT_PATH = "v3.1/chat";

    /**
     * Builds the signed chat URL for the configured host and secret.
     *
     * @param apiKeyBySystem APIKey placed in the authorization string
     * @param configuration  supplies the API host and the APISecret
     * @return the signed URL ({@code <host>v3.1/chat?authorization=...&date=...&host=...})
     * @throws Exception if the host is not a valid URL or HmacSHA256 is unavailable
     */
    public static String getKey(String apiKeyBySystem, Configuration configuration) throws Exception {
        return getKey(configuration.getApiHost(), apiKeyBySystem, configuration.getApiSecret());
    }

    /**
     * Builds the signed chat URL for an explicit host, key and secret.
     *
     * @param apiHost   base URL such as {@code https://spark-api.xf-yun.com/}; a missing trailing
     *                  slash is added
     * @param apiKey    APIKey placed in the authorization string
     * @param apiSecret APISecret used as the HMAC key
     * @return the signed URL ({@code <host>v3.1/chat?authorization=...&date=...&host=...})
     * @throws Exception if the host is not a valid URL or HmacSHA256 is unavailable
     */
    public static String getKey(String apiHost, String apiKey, String apiSecret) throws Exception {
        String date = httpDateNow();
        String base = apiHost.endsWith("/") ? apiHost : apiHost + "/";
        URL url = new URL(base + CHAT_PATH);

        // String to sign: host, date and request line, separated by '\n'.
        String signatureOrigin = "host: " + url.getHost() + "\n"
                + "date: " + date + "\n"
                + "GET " + url.getPath() + " HTTP/1.1";

        Mac mac = Mac.getInstance("HmacSHA256");
        mac.init(new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "HmacSHA256"));
        byte[] signatureBytes = mac.doFinal(signatureOrigin.getBytes(StandardCharsets.UTF_8));
        String signature = Base64.getEncoder().encodeToString(signatureBytes);

        String authorizationOrigin = String.format(
                "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"",
                apiKey, "hmac-sha256", "host date request-line", signature);
        String authorization = Base64.getEncoder()
                .encodeToString(authorizationOrigin.getBytes(StandardCharsets.UTF_8));

        // Base64 output can contain '+', '/' and '='. A raw '+' in a query string is commonly
        // decoded as a space, which breaks signature verification, so every value is URL-encoded.
        return url
                + "?authorization=" + urlEncode(authorization)
                + "&date=" + urlEncode(date)
                + "&host=" + urlEncode(url.getHost());
    }

    /** Current time as an HTTP date in GMT, e.g. {@code Fri, 05 May 2023 10:43:39 GMT}. */
    private static String httpDateNow() {
        // A new SimpleDateFormat per call: the class is not thread-safe.
        SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
        format.setTimeZone(TimeZone.getTimeZone("GMT"));
        return format.format(new Date());
    }

    /** URL-encodes a query value as UTF-8 (space becomes '+', '+' becomes %2B). */
    private static String urlEncode(String value) throws java.io.UnsupportedEncodingException {
        return java.net.URLEncoder.encode(value, StandardCharsets.UTF_8.name());
    }
}
一定得注意星火api的url格式,建议详读开发文档,星火认知大模型服务说明 | 讯飞开放平台文档中心 (xfyun.cn)
下面是测试代码
@Slf4j
public class ApiTest {
private OpenAiSession openAiSession;
private StringBuilder answer = new StringBuilder();
@Before
public void test_OpenAiSessionFactory() {
// 1. 配置文件
Configuration configuration = new Configuration();
configuration.setAppId("你的appId");
configuration.setApiHost("https://spark-api.xf-yun.com/");
configuration.setApiKey("你的apikey");
configuration.setApiSecret("你的apiSecret");
// 可以根据课程首页评论置顶说明获取 apihost、apikey;https://t.zsxq.com/0d3o5FKvc
// configuration.setAuthToken("eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ4ZmciLCJleHAiOjE2ODMyNzIyMjAsImlhdCI6MTY4MzI2ODYyMCwianRpIjoiOTkwMmM4MjItNzI2MC00OGEwLWI0NDUtN2UwZGZhOGVhOGYwIiwidXNlcm5hbWUiOiJ4ZmcifQ.Om7SdWdiIevvaWdPn7D9PnWS-ZmgbNodYTh04Tfb124");
// 2. 会话工厂
OpenAiSessionFactory factory = new DefaultOpenAiSessionFactory(configuration);
// 3. 开启会话
this.openAiSession = factory.openSession();
}
/**
* 【常用对话模式,推荐使用此模型进行测试】
* 此对话模型 3.0 接近于官网体验 & 流式应答
*/
@Test
public void test_chat_completions_stream_channel() throws Exception {
RequestDTO chatCompletion = RequestDTO
.builder()
.header(RequestDTO.HeaderDTO.builder().appId("你的appId").uid("111").build())
.parameter(RequestDTO.ParameterDTO.builder().chat(RequestDTO.ParameterDTO.ChatDTO.builder().domain("generalv3").maxTokens(2048).temperature(0.5F).build()).build())
.payload(RequestDTO.PayloadDTO.builder().message(RequestDTO.PayloadDTO.MessageDTO.builder().text(Collections.singletonList(MsgDTO.builder().role("user").content("你是谁").index(1).build())).build()).build()).build();
// 3. 发起请求
WebSocket webSocket = openAiSession.completions(chatCompletion, new WebSocketListener() {
@Override
public void onOpen(WebSocket webSocket, Response response) {
super.onOpen(webSocket, response);
log.info("连接成功");
}
@Override
public void onMessage(WebSocket webSocket, String text) {
super.onMessage(webSocket, text);
// 将大模型回复的 JSON 文本转为 ResponseDTO 对象
ResponseDTO responseData = JSONObject.parseObject(text, ResponseDTO.class);
// 如果响应数据中的 header 的 code 值不为 0,则表示响应错误
if (responseData.getHeader().getCode() != 0) {
// 日志记录
log.error("发生错误,错误码为:" + responseData.getHeader().getCode() + "; " + "信息:" + responseData.getHeader().getMessage());
return;
}
// 将回答进行拼接
for (MsgDTO msgDTO : responseData.getPayload().getChoices().getText()) {
// apiTest.answer.append(msgDTO.getContent());
log.info("text:"+msgDTO.getContent());
}
/* // 对最后一个文本结果进行处理
if (2 == responseData.getHeader().getStatus()) {
wsCloseFlag = true;
}*/
}
@Override
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
super.onFailure(webSocket, t, response);
log.error("error:"+response.message());
}
});
// 等待
new CountDownLatch(1).await();
}
}
现在应该就能正常使用讯飞的星火大模型了