基于OpenAI Function Call 结合Lucene实现本地化的知识搜索(二):Functions调用流程说明

今天完成了OpenAI Functions调用的功能。今天对Functions调用的流程写一下简单的说明。

先上代码,调用主要集中在了三个类里,昨天文章中已经列了出来,今天又完善了代码。

基于OpenAI Function Call 结合Lucene实现本地化的知识搜索(一)

image.png

类名 作用说明
AbsFunctionService 函数服务的抽象类,定义基础的一些处理方法
FunctionResult 统一返回值,包含code和data,code==0是正常返回
FunctionFactory 函数处理工厂,生产函数对象和在过程中对数据处理
FunctionAnnotation 需要注册的方法和参数的说明
WeatherApiService 天气服务插件函数

基础服务类说明

package com.lgf.warehouse.modules.ai.openai.functions;








import cn.hutool.json.JSONObject;
import com.lgf.warehouse.core.chatgpt.entity.chat.Functions;
import com.lgf.warehouse.core.chatgpt.entity.chat.Parameters;


import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;





/**

 * 方法服务抽象类
 */

public abstract class AbsFunctionService {

    public abstract String getFunctionName();


    public abstract Class getCla();



    /**
     * 执行方法
     * @param methodName
     * @param params
     * @return
     */
    public Object execute(String methodName, Map<String,Object> params) throws NoSuchMethodException {
        Method method=this.getMethodByName(methodName);
        try {
            Object[] objs=getArgumentsArray(params,method.getParameters());
            return method.invoke(this ,objs);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }




    /**
     * 根据方法的参数配置返回参数列表
     * @param paramMap
     * @param parameters
     * @return
     */
    private static Object[] getArgumentsArray(Map<String, Object> paramMap, Parameter[] parameters) {
        Object[] argsArray = new Object[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            Parameter parameter = parameters[i];
            Object paramValue = paramMap.getOrDefault(parameter.getName(), null);
            argsArray[i] = paramValue;
        }
        return argsArray;
    }

    /**
     * 根据方法名获取方法
     * @param methodName
     * @return
     * @throws NoSuchMethodException
     */
    private Method getMethodByName(String methodName) throws NoSuchMethodException {
        Method[] methods=this.getCla().getMethods();
        for(Method method:methods){
            if(method.getName().equals(methodName)){
                return method;
            }
        }
        return this.getCla().getMethod(methodName);
    }

    /**
     * 获取方法
     * @return
     * @throws NoSuchMethodException
     */
    public List<Functions> getFunctions() throws NoSuchMethodException {
        Class cla = this.getCla();
        Method[] methods=cla.getMethods();
        List<Functions> functions=new ArrayList<>();
        for(Method method: methods) {
            Annotation[] as=method.getAnnotations();
            FunctionAnnotation methodFun = method.getAnnotation(FunctionAnnotation.class);
            if(methodFun==null){
                continue;
            }
            String description = methodFun.describe();
            Parameter[] params = method.getParameters();
            Parameters parameters = this.getParameters(params);
            Functions function = Functions.builder()
                    .name(this.getFunctionName() + "_" + method.getName())
                    .description(description)
                    .parameters(parameters)
                    .build();
            functions.add(function);
        }
        return functions;
    }

    /**
     * 获取参数
     *
     * @param parameters
     * @return
     */
    private Parameters getParameters(Parameter[] parameters) {
        JSONObject params = new JSONObject();
        List<String> requireds = new ArrayList<>();
        for (Parameter parameter : parameters) {
            FunctionAnnotation paramFun = parameter.getAnnotation(FunctionAnnotation.class);

            JSONObject param = new JSONObject();
            String type = this.convertParameterToJsonSchemaType(parameter);
            param.putOpt("type", type);

            if (paramFun != null) {
                if (paramFun.required()) {
                    requireds.add(parameter.getName());
                }
                if (paramFun.enums().length > 0) {
                    param.putOpt("enum", Arrays.asList(paramFun.enums()));
                }
                param.putOpt("description", paramFun.describe());
            }
            params.putOpt(parameter.getName(), param);
        }
        Parameters result = Parameters.builder()
                .type("object")
                .properties(params)
                .required(requireds)
                .build();

        return result;
    }

    /**
     * 将JAVA的参数类型转换为JSON Schema的类型
     * @param parameter
     * @return
     */
    public String convertParameterToJsonSchemaType(Parameter parameter) {
        String parameterTypeName = parameter.getType().getSimpleName().toLowerCase();
        String jsonSchemaType = null;
        switch (parameterTypeName) {
            case "string":
                jsonSchemaType = "string";
                break;
            case "boolean":
                jsonSchemaType = "boolean";
                break;
            case "byte":
            case "short":
            case "int":
            case "long":
            case "float":
            case "double":
                jsonSchemaType = "number";
                break;
            default:
                if (parameterTypeName.endsWith("[]")) {
                    jsonSchemaType = "array";
                } else {
                    jsonSchemaType = "object";
                }
        }
        return jsonSchemaType;
    }
}
package com.lgf.warehouse.modules.ai.openai.functions;








import cn.hutool.json.JSONUtil;
import com.lgf.warehouse.core.Constants;
import lombok.Data;


import java.util.Objects;


/**

 * 函数执行结果
 */

@Data
public class FunctionResult {




    private Integer code;

    private Object data;

    public FunctionResult() {
        this.code = Constants.Result.Code.SUCCESS;
    }


    public FunctionResult(Object data) {
        this.code = Constants.Result.Code.SUCCESS;
        this.data = data;
    }


    /**

     * 失败
     *
     * @return
     */
    public static FunctionResult failure() {
        FunctionResult result = new FunctionResult();
        result.setCode(Constants.Result.Code.ERROR);
        return result;
    }

    /**
     * 成功
     *
     * @param data
     * @return
     */
    public static FunctionResult ok(Object data) {
        return new FunctionResult(data);
    }

    public boolean isSuccess() {
        return Objects.equals(this.code, Constants.Result.Code.SUCCESS);
    }

    public String toString() {
        if(this.getData()==null){
            return "";
        }else {
            return JSONUtil.toJsonStr(this.getData());
        }
    }
}

统一一个返回值格式,方便后续的数据处理,并且将Result类的toString修改为data的JSON对象,OpenAI更理解JSON值的数据内容。

package com.lgf.warehouse.modules.ai.openai.functions;








import com.lgf.warehouse.core.chatgpt.entity.chat.ChatChoice;
import com.lgf.warehouse.core.chatgpt.entity.chat.FunctionCall;
import com.lgf.warehouse.core.chatgpt.entity.chat.Functions;
import com.lgf.warehouse.core.chatgpt.entity.chat.Message;
import com.lgf.warehouse.modules.ai.openai.functions.weather.service.WeatherApiService;
import jakarta.annotation.PostConstruct;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.*;




/**

 * 方法生产工厂
 */

@Slf4j
@Service
public class FunctionFactory {
    @Autowired
    private WeatherApiService weatherApiService;



    //    函数定义集合
    private Map<String, AbsFunctionService> functionBeanMap;

    private List<Functions> functions;

    /**
     * 注册服务到FunctionFactory
     */

    @PostConstruct()
    public void register() throws NoSuchMethodException {
        this.functionBeanMap = new HashMap<>();
        this.functions = new ArrayList<>();
        //注册天气服务
        this.registerFunction(this.weatherApiService);
        log.info(String.format("注册了%d个服务", this.functionBeanMap.size()));
    }



    private void registerFunction(AbsFunctionService functionService) throws NoSuchMethodException {
        this.functionBeanMap.put(functionService.getFunctionName(), functionService);
        this.functions.addAll(functionService.getFunctions());
    }

    // 定义公共的返回值
    public FunctionResult execute(String functionName, Map<String, Object> params) throws NoSuchMethodException {
        String[] sp = functionName.split("_");
        if (sp.length != 2) {
            throw new RuntimeException("方法名称不正确");
        }
        String beanName = sp[0];
        AbsFunctionService functionService = this.functionBeanMap.get(beanName);
        if (functionService == null) {
            throw new RuntimeException("没有找到对应的Bean");
        }
        //通过反射执行functionService中的functionName名称的方法
        Object result = functionService.execute(sp[1], params);
        if (result == null) {
            return FunctionResult.failure();
        } else {
            return FunctionResult.ok(result);
        }
    }

    /**
     * 包装Messages,在第二次请求的时候,将上下文信息都传入到OpenAI中,方便OpenAI理解,然后回复我
     * @param message
     * @param chatChoice
     * @param functionResult
     * @return
     */
    public List<Message> getMessages(Message message, ChatChoice chatChoice,FunctionResult functionResult) {
        FunctionCall fc = chatChoice.getMessage().getFunctionCall();
        FunctionCall functionCall = FunctionCall.builder()
                .arguments(fc.getArguments())
                .name(fc.getName())
                .build();
        //辅助消息,说明方法的参数信息
        Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").functionCall(functionCall).build();
        String content = functionResult.toString();
        //方法说明的消息,将方法的返回值告诉OpenAI
        Message message3 = Message.builder().role(Message.Role.FUNCTION).name(fc.getName()).content(content).build();
        //把问题串起来给OpenAI
        List<Message> messageList = Arrays.asList(message, message2, message3);
        return messageList;
    }


    /**
     * 获取方法集合
     *
     * @return
     */
    public List<Functions> getFunctions() {
        return this.functions;
    }

}

这个是接下来系统的核心处理方法,用来调用方法和创建方法。

package com.lgf.warehouse.modules.ai.openai.functions;








import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.Map;


/**

 * GPT函数的注释
 */

@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD,ElementType.PARAMETER})
public @interface FunctionAnnotation {

    /**
     * 描述
     * @return
     */
    public String describe() default "";


    /**
     * 枚举数据
     * @return
     */
    public String[] enums() default {};


    /**

     * 是否必须
     * @return
     */

    public boolean required() default false;
}

这个需要注意,必须在方法上定义,参数上尽量全部加上描述。

Demo

package com.lgf.warehouse.modules.ai.openai.functions.weather.service;





import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.lgf.warehouse.modules.ai.openai.functions.AbsFunctionService;
import com.lgf.warehouse.modules.ai.openai.functions.FunctionAnnotation;
import com.lgf.warehouse.modules.ai.openai.functions.weather.vo.Weather;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import java.util.HashMap;
import java.util.Map;





@Service
public class WeatherApiService  extends AbsFunctionService {

    @Value("${functions.weather.api.url}")
    private String api;
    @Value("${functions.weather.api.key}")
    private String key;





    @FunctionAnnotation(describe = "获取天气预报")
    public Weather getWeatherByCity(@FunctionAnnotation(describe = "城市名称,如果是中文需要转换为拼音,比如洛阳,需要转换为LuoYang",required = true) String city,
                              @FunctionAnnotation(describe = "要查询的是哪一天的天气,今天是1,明天是2,后天是3,最多查询3天的天气,如果超出3天,则返回1",required = true,enums = {"1","2","3"}) Integer days) {
        //调用天气接口
        Map<String,Object> params=new HashMap<>();
        params.put("key",this.key);
        params.put("q",city);
        params.put("days",days);
        String result= HttpUtil.get(this.api,params);
        if(JSONUtil.isTypeJSONObject(result)){
            JSONObject response=JSONUtil.parseObj(result);
            if(response.get("current")!=null){
                JSONObject location=response.getJSONObject("location");
                JSONObject current=response.getJSONObject("current");
                JSONObject condition=current.getJSONObject("condition");
                // 解析天气接口返回的数据
                Weather weather= new Weather();
                weather.setCity(location.getStr("name"));
                weather.setRegion(location.getStr("region"));
                weather.setCountry(location.getStr("country"));

                weather.setDays(days);

                weather.setCondition(condition.getStr("text"));
                weather.setConditionIcon(condition.getStr("icon"));

                weather.setTemp(current.getDouble("temp_c"));
                return weather;
            }
        }
        return null;
    }

    @Override
    public String getFunctionName() {
        return "WeatherApiService";
    }

    @Override
    public Class getCla() {
        return this.getClass();
    }
}

天气接口用的是api.weatherapi.com/v1/current.… 免费接口,缺点是只能查询市级的天气,并且需要英文(中文地市就是拼音)。

结尾

夸一下GitHub的Copilot,特别是处理这种API接口的时候,调用一次接口,把返回值复制到Idea中后注释掉,然后写代码的时候会自动生成解析结构的代码,太爽了,减少这种码字工作。

© 版权声明
THE END
喜欢就支持一下吧
点赞0

Warning: mysqli_query(): (HY000/3): Error writing file '/tmp/MYuy2WcA' (Errcode: 28 - No space left on device) in /www/wwwroot/583.cn/wp-includes/class-wpdb.php on line 2345
admin的头像-五八三
评论 抢沙发
头像
欢迎您留下宝贵的见解!
提交
头像

昵称

图形验证码
取消
昵称代码图片