feat: streaming json parser (#1883)
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
local Utils = require("avante.utils")
|
||||
local Config = require("avante.config")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local P = require("avante.providers")
|
||||
local Providers = require("avante.providers")
|
||||
local StreamingJsonParser = require("avante.utils.streaming_json_parser")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
@@ -73,7 +74,7 @@ end
|
||||
|
||||
function M:parse_messages(opts)
|
||||
local messages = {}
|
||||
local provider_conf, _ = P.parse_config(self)
|
||||
local provider_conf, _ = Providers.parse_config(self)
|
||||
|
||||
if self.is_reasoning_model(provider_conf.model) then
|
||||
table.insert(messages, { role = "developer", content = opts.system_prompt })
|
||||
@@ -224,18 +225,37 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
ctx.last_think_content = choice.delta.reasoning
|
||||
opts.on_chunk(choice.delta.reasoning)
|
||||
elseif choice.delta.tool_calls and choice.delta.tool_calls ~= vim.NIL then
|
||||
local tool_call = choice.delta.tool_calls[1]
|
||||
if not ctx.tool_use_list then ctx.tool_use_list = {} end
|
||||
if not ctx.tool_use_list[tool_call.index + 1] then
|
||||
local tool_use = {
|
||||
name = tool_call["function"].name,
|
||||
id = tool_call.id,
|
||||
input_json = "",
|
||||
}
|
||||
ctx.tool_use_list[tool_call.index + 1] = tool_use
|
||||
else
|
||||
local tool_use = ctx.tool_use_list[tool_call.index + 1]
|
||||
tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
|
||||
for _, tool_call in ipairs(choice.delta.tool_calls) do
|
||||
if not ctx.tool_use_list then ctx.tool_use_list = {} end
|
||||
if not ctx.tool_use_list[tool_call.index + 1] then
|
||||
local tool_use = {
|
||||
name = tool_call["function"].name,
|
||||
id = tool_call.id,
|
||||
input_json = "",
|
||||
}
|
||||
ctx.tool_use_list[tool_call.index + 1] = tool_use
|
||||
if opts.on_partial_tool_use then
|
||||
opts.on_partial_tool_use({
|
||||
name = tool_call["function"].name,
|
||||
id = tool_call.id,
|
||||
partial_json = {},
|
||||
state = "generating",
|
||||
})
|
||||
end
|
||||
else
|
||||
local tool_use = ctx.tool_use_list[tool_call.index + 1]
|
||||
tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
|
||||
if opts.on_partial_tool_use then
|
||||
local parser = StreamingJsonParser:new()
|
||||
local partial_json = parser:parse(tool_use.input_json)
|
||||
opts.on_partial_tool_use({
|
||||
name = tool_call["function"].name,
|
||||
id = tool_call.id,
|
||||
partial_json = partial_json or {},
|
||||
state = "generating",
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
elseif choice.delta.content then
|
||||
if
|
||||
@@ -271,7 +291,7 @@ function M:parse_response_without_stream(data, _, opts)
|
||||
end
|
||||
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
local provider_conf, request_body = Providers.parse_config(self)
|
||||
local disable_tools = provider_conf.disable_tools or false
|
||||
|
||||
local headers = {
|
||||
@@ -284,7 +304,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
end
|
||||
end
|
||||
|
||||
if P.env.require_api_key(provider_conf) then
|
||||
if Providers.env.require_api_key(provider_conf) then
|
||||
local api_key = self.parse_api_key()
|
||||
if api_key == nil then
|
||||
error(Config.provider .. " API key is not set, please set it in your environment variable or config file")
|
||||
|
||||
Reference in New Issue
Block a user