今天完成了OpenAI Functions调用的功能。今天对Functions调用的流程写一下简单的说明。
先上代码,调用主要集中在了三个类里,昨天文章中已经列了出来,今天又完善了代码。
基于OpenAI Function Call 结合Lucene实现本地化的知识搜索(一)
类名 | 作用说明 |
---|---|
AbsFunctionService | 函数服务的抽象类,定义基础的一些处理方法 |
FunctionResult | 统一返回值,包含code和data,code==0是正常返回 |
FunctionFactory | 函数处理工厂,生产函数对象和在过程中对数据处理 |
FunctionAnnotation | 需要注册的方法和参数的说明 |
WeatherApiService | 天气服务插件函数 |
基础服务类说明
package com.lgf.warehouse.modules.ai.openai.functions;
import cn.hutool.json.JSONObject;
import com.lgf.warehouse.core.chatgpt.entity.chat.Functions;
import com.lgf.warehouse.core.chatgpt.entity.chat.Parameters;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
/**
* 方法服务抽象类
*/
public abstract class AbsFunctionService {
public abstract String getFunctionName();
public abstract Class getCla();
/**
* 执行方法
* @param methodName
* @param params
* @return
*/
public Object execute(String methodName, Map<String,Object> params) throws NoSuchMethodException {
Method method=this.getMethodByName(methodName);
try {
Object[] objs=getArgumentsArray(params,method.getParameters());
return method.invoke(this ,objs);
} catch (Exception e) {
e.printStackTrace();
}
return null;
}
/**
* 根据方法的参数配置返回参数列表
* @param paramMap
* @param parameters
* @return
*/
private static Object[] getArgumentsArray(Map<String, Object> paramMap, Parameter[] parameters) {
Object[] argsArray = new Object[parameters.length];
for (int i = 0; i < parameters.length; i++) {
Parameter parameter = parameters[i];
Object paramValue = paramMap.getOrDefault(parameter.getName(), null);
argsArray[i] = paramValue;
}
return argsArray;
}
/**
* 根据方法名获取方法
* @param methodName
* @return
* @throws NoSuchMethodException
*/
private Method getMethodByName(String methodName) throws NoSuchMethodException {
Method[] methods=this.getCla().getMethods();
for(Method method:methods){
if(method.getName().equals(methodName)){
return method;
}
}
return this.getCla().getMethod(methodName);
}
/**
* 获取方法
* @return
* @throws NoSuchMethodException
*/
public List<Functions> getFunctions() throws NoSuchMethodException {
Class cla = this.getCla();
Method[] methods=cla.getMethods();
List<Functions> functions=new ArrayList<>();
for(Method method: methods) {
Annotation[] as=method.getAnnotations();
FunctionAnnotation methodFun = method.getAnnotation(FunctionAnnotation.class);
if(methodFun==null){
continue;
}
String description = methodFun.describe();
Parameter[] params = method.getParameters();
Parameters parameters = this.getParameters(params);
Functions function = Functions.builder()
.name(this.getFunctionName() + "_" + method.getName())
.description(description)
.parameters(parameters)
.build();
functions.add(function);
}
return functions;
}
/**
* 获取参数
*
* @param parameters
* @return
*/
private Parameters getParameters(Parameter[] parameters) {
JSONObject params = new JSONObject();
List<String> requireds = new ArrayList<>();
for (Parameter parameter : parameters) {
FunctionAnnotation paramFun = parameter.getAnnotation(FunctionAnnotation.class);
JSONObject param = new JSONObject();
String type = this.convertParameterToJsonSchemaType(parameter);
param.putOpt("type", type);
if (paramFun != null) {
if (paramFun.required()) {
requireds.add(parameter.getName());
}
if (paramFun.enums().length > 0) {
param.putOpt("enum", Arrays.asList(paramFun.enums()));
}
param.putOpt("description", paramFun.describe());
}
params.putOpt(parameter.getName(), param);
}
Parameters result = Parameters.builder()
.type("object")
.properties(params)
.required(requireds)
.build();
return result;
}
/**
* 将JAVA的参数类型转换为JSON Schema的类型
* @param parameter
* @return
*/
public String convertParameterToJsonSchemaType(Parameter parameter) {
String parameterTypeName = parameter.getType().getSimpleName().toLowerCase();
String jsonSchemaType = null;
switch (parameterTypeName) {
case "string":
jsonSchemaType = "string";
break;
case "boolean":
jsonSchemaType = "boolean";
break;
case "byte":
case "short":
case "int":
case "long":
case "float":
case "double":
jsonSchemaType = "number";
break;
default:
if (parameterTypeName.endsWith("[]")) {
jsonSchemaType = "array";
} else {
jsonSchemaType = "object";
}
}
return jsonSchemaType;
}
}
package com.lgf.warehouse.modules.ai.openai.functions;
import cn.hutool.json.JSONUtil;
import com.lgf.warehouse.core.Constants;
import lombok.Data;
import java.util.Objects;
/**
* 函数执行结果
*/
@Data
public class FunctionResult {
private Integer code;
private Object data;
public FunctionResult() {
this.code = Constants.Result.Code.SUCCESS;
}
public FunctionResult(Object data) {
this.code = Constants.Result.Code.SUCCESS;
this.data = data;
}
/**
* 失败
*
* @return
*/
public static FunctionResult failure() {
FunctionResult result = new FunctionResult();
result.setCode(Constants.Result.Code.ERROR);
return result;
}
/**
* 成功
*
* @param data
* @return
*/
public static FunctionResult ok(Object data) {
return new FunctionResult(data);
}
public boolean isSuccess() {
return Objects.equals(this.code, Constants.Result.Code.SUCCESS);
}
public String toString() {
if(this.getData()==null){
return "";
}else {
return JSONUtil.toJsonStr(this.getData());
}
}
}
统一一个返回值格式,方便后续的数据处理,并且将Result类的toString修改为data的JSON对象,OpenAI更理解JSON值的数据内容。
package com.lgf.warehouse.modules.ai.openai.functions;
import com.lgf.warehouse.core.chatgpt.entity.chat.ChatChoice;
import com.lgf.warehouse.core.chatgpt.entity.chat.FunctionCall;
import com.lgf.warehouse.core.chatgpt.entity.chat.Functions;
import com.lgf.warehouse.core.chatgpt.entity.chat.Message;
import com.lgf.warehouse.modules.ai.openai.functions.weather.service.WeatherApiService;
import jakarta.annotation.PostConstruct;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.*;
/**
* 方法生产工厂
*/
@Slf4j
@Service
public class FunctionFactory {
@Autowired
private WeatherApiService weatherApiService;
// 函数定义集合
private Map<String, AbsFunctionService> functionBeanMap;
private List<Functions> functions;
/**
* 注册服务到FunctionFactory
*/
@PostConstruct()
public void register() throws NoSuchMethodException {
this.functionBeanMap = new HashMap<>();
this.functions = new ArrayList<>();
//注册天气服务
this.registerFunction(this.weatherApiService);
log.info(String.format("注册了%d个服务", this.functionBeanMap.size()));
}
private void registerFunction(AbsFunctionService functionService) throws NoSuchMethodException {
this.functionBeanMap.put(functionService.getFunctionName(), functionService);
this.functions.addAll(functionService.getFunctions());
}
// 定义公共的返回值
public FunctionResult execute(String functionName, Map<String, Object> params) throws NoSuchMethodException {
String[] sp = functionName.split("_");
if (sp.length != 2) {
throw new RuntimeException("方法名称不正确");
}
String beanName = sp[0];
AbsFunctionService functionService = this.functionBeanMap.get(beanName);
if (functionService == null) {
throw new RuntimeException("没有找到对应的Bean");
}
//通过反射执行functionService中的functionName名称的方法
Object result = functionService.execute(sp[1], params);
if (result == null) {
return FunctionResult.failure();
} else {
return FunctionResult.ok(result);
}
}
/**
* 包装Messages,在第二次请求的时候,将上下文信息都传入到OpenAI中,方便OpenAI理解,然后回复我
* @param message
* @param chatChoice
* @param functionResult
* @return
*/
public List<Message> getMessages(Message message, ChatChoice chatChoice,FunctionResult functionResult) {
FunctionCall fc = chatChoice.getMessage().getFunctionCall();
FunctionCall functionCall = FunctionCall.builder()
.arguments(fc.getArguments())
.name(fc.getName())
.build();
//辅助消息,说明方法的参数信息
Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").functionCall(functionCall).build();
String content = functionResult.toString();
//方法说明的消息,将方法的返回值告诉OpenAI
Message message3 = Message.builder().role(Message.Role.FUNCTION).name(fc.getName()).content(content).build();
//把问题串起来给OpenAI
List<Message> messageList = Arrays.asList(message, message2, message3);
return messageList;
}
/**
* 获取方法集合
*
* @return
*/
public List<Functions> getFunctions() {
return this.functions;
}
}
这个是接下来系统的核心处理方法,用来调用方法和创建方法。
package com.lgf.warehouse.modules.ai.openai.functions;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.Map;
/**
* GPT函数的注释
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD,ElementType.PARAMETER})
public @interface FunctionAnnotation {
/**
* 描述
* @return
*/
public String describe() default "";
/**
* 枚举数据
* @return
*/
public String[] enums() default {};
/**
* 是否必须
* @return
*/
public boolean required() default false;
}
这个需要注意,必须在方法上定义,参数上尽量全部加上描述。
Demo
package com.lgf.warehouse.modules.ai.openai.functions.weather.service;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.lgf.warehouse.modules.ai.openai.functions.AbsFunctionService;
import com.lgf.warehouse.modules.ai.openai.functions.FunctionAnnotation;
import com.lgf.warehouse.modules.ai.openai.functions.weather.vo.Weather;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.Map;
@Service
public class WeatherApiService extends AbsFunctionService {
@Value("${functions.weather.api.url}")
private String api;
@Value("${functions.weather.api.key}")
private String key;
@FunctionAnnotation(describe = "获取天气预报")
public Weather getWeatherByCity(@FunctionAnnotation(describe = "城市名称,如果是中文需要转换为拼音,比如洛阳,需要转换为LuoYang",required = true) String city,
@FunctionAnnotation(describe = "要查询的是哪一天的天气,今天是1,明天是2,后天是3,最多查询3天的天气,如果超出3天,则返回1",required = true,enums = {"1","2","3"}) Integer days) {
//调用天气接口
Map<String,Object> params=new HashMap<>();
params.put("key",this.key);
params.put("q",city);
params.put("days",days);
String result= HttpUtil.get(this.api,params);
if(JSONUtil.isTypeJSONObject(result)){
JSONObject response=JSONUtil.parseObj(result);
if(response.get("current")!=null){
JSONObject location=response.getJSONObject("location");
JSONObject current=response.getJSONObject("current");
JSONObject condition=current.getJSONObject("condition");
// 解析天气接口返回的数据
Weather weather= new Weather();
weather.setCity(location.getStr("name"));
weather.setRegion(location.getStr("region"));
weather.setCountry(location.getStr("country"));
weather.setDays(days);
weather.setCondition(condition.getStr("text"));
weather.setConditionIcon(condition.getStr("icon"));
weather.setTemp(current.getDouble("temp_c"));
return weather;
}
}
return null;
}
@Override
public String getFunctionName() {
return "WeatherApiService";
}
@Override
public Class getCla() {
return this.getClass();
}
}
天气接口用的是api.weatherapi.com/v1/current.… 免费接口,缺点是只能查询市级的天气,并且需要英文(中文地市就是拼音)。
结尾
夸一下GitHub的Copilot,特别是处理这种API接口的时候,调用一次接口,把返回值复制到Idea中后注释掉,然后写代码的时候会自动生成解析结构的代码,太爽了,减少这种码字工作。
© 版权声明
文章版权归作者所有,未经允许请勿转载,侵权请联系 admin@trc20.tw 删除。
THE END