fix: multiple tool use histories and disable tools (#1185)
This commit is contained in:
@@ -17,7 +17,7 @@ M.parse_messages = O.parse_messages
|
||||
M.parse_response = O.parse_response
|
||||
M.parse_response_without_stream = O.parse_response_without_stream
|
||||
|
||||
M.parse_curl_args = function(provider, code_opts)
|
||||
M.parse_curl_args = function(provider, prompt_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
|
||||
local headers = {
|
||||
@@ -40,7 +40,7 @@ M.parse_curl_args = function(provider, code_opts)
|
||||
insecure = base.allow_insecure,
|
||||
headers = headers,
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
messages = M.parse_messages(code_opts),
|
||||
messages = M.parse_messages(prompt_opts),
|
||||
stream = true,
|
||||
}, body_opts),
|
||||
}
|
||||
|
||||
@@ -112,38 +112,42 @@ M.parse_messages = function(opts)
|
||||
messages[#messages].content = message_content
|
||||
end
|
||||
|
||||
if opts.tool_use then
|
||||
local msg = {
|
||||
role = "assistant",
|
||||
content = {},
|
||||
}
|
||||
if opts.response_content then
|
||||
msg.content[#msg.content + 1] = {
|
||||
type = "text",
|
||||
text = opts.response_content,
|
||||
}
|
||||
end
|
||||
msg.content[#msg.content + 1] = {
|
||||
type = "tool_use",
|
||||
id = opts.tool_use.id,
|
||||
name = opts.tool_use.name,
|
||||
input = vim.json.decode(opts.tool_use.input_json),
|
||||
}
|
||||
messages[#messages + 1] = msg
|
||||
end
|
||||
if opts.tool_histories then
|
||||
for _, tool_history in ipairs(opts.tool_histories) do
|
||||
if tool_history.tool_use then
|
||||
local msg = {
|
||||
role = "assistant",
|
||||
content = {},
|
||||
}
|
||||
if tool_history.response_content then
|
||||
msg.content[#msg.content + 1] = {
|
||||
type = "text",
|
||||
text = tool_history.response_content,
|
||||
}
|
||||
end
|
||||
msg.content[#msg.content + 1] = {
|
||||
type = "tool_use",
|
||||
id = tool_history.tool_use.id,
|
||||
name = tool_history.tool_use.name,
|
||||
input = vim.json.decode(tool_history.tool_use.input_json),
|
||||
}
|
||||
messages[#messages + 1] = msg
|
||||
end
|
||||
|
||||
if opts.tool_result then
|
||||
messages[#messages + 1] = {
|
||||
role = "user",
|
||||
content = {
|
||||
{
|
||||
type = "tool_result",
|
||||
tool_use_id = opts.tool_result.tool_use_id,
|
||||
content = opts.tool_result.content,
|
||||
is_error = opts.tool_result.is_error,
|
||||
},
|
||||
},
|
||||
}
|
||||
if tool_history.tool_result then
|
||||
messages[#messages + 1] = {
|
||||
role = "user",
|
||||
content = {
|
||||
{
|
||||
type = "tool_result",
|
||||
tool_use_id = tool_history.tool_result.tool_use_id,
|
||||
content = tool_history.tool_result.content,
|
||||
is_error = tool_history.tool_result.is_error,
|
||||
},
|
||||
},
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return messages
|
||||
|
||||
@@ -69,7 +69,7 @@ M.parse_stream_data = function(data, opts)
|
||||
end
|
||||
end
|
||||
|
||||
M.parse_curl_args = function(provider, code_opts)
|
||||
M.parse_curl_args = function(provider, prompt_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
|
||||
local headers = {
|
||||
@@ -92,7 +92,7 @@ M.parse_curl_args = function(provider, code_opts)
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = base.model,
|
||||
stream = true,
|
||||
}, M.parse_messages(code_opts), body_opts),
|
||||
}, M.parse_messages(prompt_opts), body_opts),
|
||||
}
|
||||
end
|
||||
|
||||
|
||||
@@ -241,7 +241,7 @@ end
|
||||
|
||||
M.parse_response = OpenAI.parse_response
|
||||
|
||||
M.parse_curl_args = function(provider, code_opts)
|
||||
M.parse_curl_args = function(provider, prompt_opts)
|
||||
-- refresh token synchronously, only if it has expired
|
||||
-- (this should rarely happen, as we refresh the token in the background)
|
||||
H.refresh_token(false, false)
|
||||
@@ -249,8 +249,8 @@ M.parse_curl_args = function(provider, code_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
|
||||
local tools = {}
|
||||
if code_opts.tools then
|
||||
for _, tool in ipairs(code_opts.tools) do
|
||||
if prompt_opts.tools then
|
||||
for _, tool in ipairs(prompt_opts.tools) do
|
||||
table.insert(tools, OpenAI.transform_tool(tool))
|
||||
end
|
||||
end
|
||||
@@ -268,7 +268,7 @@ M.parse_curl_args = function(provider, code_opts)
|
||||
},
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = base.model,
|
||||
messages = M.parse_messages(code_opts),
|
||||
messages = M.parse_messages(prompt_opts),
|
||||
stream = true,
|
||||
tools = tools,
|
||||
}, body_opts),
|
||||
|
||||
@@ -81,7 +81,7 @@ M.parse_response = function(ctx, data_stream, _, opts)
|
||||
end
|
||||
end
|
||||
|
||||
M.parse_curl_args = function(provider, code_opts)
|
||||
M.parse_curl_args = function(provider, prompt_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
|
||||
body_opts = vim.tbl_deep_extend("force", body_opts, {
|
||||
@@ -101,7 +101,7 @@ M.parse_curl_args = function(provider, code_opts)
|
||||
proxy = base.proxy,
|
||||
insecure = base.allow_insecure,
|
||||
headers = { ["Content-Type"] = "application/json" },
|
||||
body = vim.tbl_deep_extend("force", {}, M.parse_messages(code_opts), body_opts),
|
||||
body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts),
|
||||
}
|
||||
end
|
||||
|
||||
|
||||
@@ -30,9 +30,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
|
||||
---@field messages AvanteLLMMessage[]
|
||||
---@field image_paths? string[]
|
||||
---@field tools? AvanteLLMTool[]
|
||||
---@field tool_result? AvanteLLMToolResult
|
||||
---@field tool_use? AvanteLLMToolUse
|
||||
---@field response_content? string
|
||||
---@field tool_histories? AvanteLLMToolHistory[]
|
||||
---
|
||||
---@class AvanteGeminiMessage
|
||||
---@field role "user"
|
||||
@@ -43,7 +41,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
|
||||
---@alias AvanteMessagesParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
|
||||
---
|
||||
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
|
||||
---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||
---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput
|
||||
---
|
||||
---@class ResponseParser
|
||||
---@field on_start AvanteLLMStartCallback
|
||||
@@ -60,6 +58,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
|
||||
---@field allow_insecure? boolean
|
||||
---@field api_key_name? string
|
||||
---@field _shellenv? string
|
||||
---@field disable_tools? boolean
|
||||
---
|
||||
---@class AvanteSupportedProvider: AvanteDefaultBaseProvider
|
||||
---@field __inherited_from? string
|
||||
@@ -382,26 +381,31 @@ function M.refresh(provider)
|
||||
end
|
||||
|
||||
---@param opts AvanteProvider | AvanteSupportedProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||
---@return AvanteDefaultBaseProvider, table<string, any>
|
||||
---@return AvanteDefaultBaseProvider provider_opts
|
||||
---@return table<string, any> request_body
|
||||
M.parse_config = function(opts)
|
||||
---@type AvanteDefaultBaseProvider
|
||||
local s1 = {}
|
||||
local provider_opts = {}
|
||||
---@type table<string, any>
|
||||
local s2 = {}
|
||||
local request_body = {}
|
||||
|
||||
for key, value in pairs(opts) do
|
||||
if vim.tbl_contains(Config.BASE_PROVIDER_KEYS, key) then
|
||||
s1[key] = value
|
||||
provider_opts[key] = value
|
||||
else
|
||||
s2[key] = value
|
||||
request_body[key] = value
|
||||
end
|
||||
end
|
||||
|
||||
return s1,
|
||||
vim.iter(s2):filter(function(_, v) return type(v) ~= "function" end):fold({}, function(acc, k, v)
|
||||
request_body = vim
|
||||
.iter(request_body)
|
||||
:filter(function(_, v) return type(v) ~= "function" end)
|
||||
:fold({}, function(acc, k, v)
|
||||
acc[k] = v
|
||||
return acc
|
||||
end)
|
||||
|
||||
return provider_opts, request_body
|
||||
end
|
||||
|
||||
---@private
|
||||
|
||||
@@ -167,26 +167,28 @@ M.parse_messages = function(opts)
|
||||
table.insert(final_messages, { role = M.role_map[role] or role, content = message.content })
|
||||
end)
|
||||
|
||||
if opts.tool_result then
|
||||
table.insert(final_messages, {
|
||||
role = M.role_map["assistant"],
|
||||
tool_calls = {
|
||||
{
|
||||
id = opts.tool_use.id,
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = opts.tool_use.name,
|
||||
arguments = opts.tool_use.input_json,
|
||||
if opts.tool_histories then
|
||||
for _, tool_history in ipairs(opts.tool_histories) do
|
||||
table.insert(final_messages, {
|
||||
role = M.role_map["assistant"],
|
||||
tool_calls = {
|
||||
{
|
||||
id = tool_history.tool_use.id,
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool_history.tool_use.name,
|
||||
arguments = tool_history.tool_use.input_json,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
local result_content = opts.tool_result.content or ""
|
||||
table.insert(final_messages, {
|
||||
role = "tool",
|
||||
tool_call_id = opts.tool_result.tool_use_id,
|
||||
content = opts.tool_result.is_error and "Error: " .. result_content or result_content,
|
||||
})
|
||||
})
|
||||
local result_content = tool_history.tool_result.content or ""
|
||||
table.insert(final_messages, {
|
||||
role = "tool",
|
||||
tool_call_id = tool_history.tool_result.tool_use_id,
|
||||
content = tool_history.tool_result.is_error and "Error: " .. result_content or result_content,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
return final_messages
|
||||
@@ -269,8 +271,9 @@ M.parse_response_without_stream = function(data, _, opts)
|
||||
end
|
||||
end
|
||||
|
||||
M.parse_curl_args = function(provider, code_opts)
|
||||
M.parse_curl_args = function(provider, prompt_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
local disable_tools = base.disable_tools or false
|
||||
|
||||
local headers = {
|
||||
["Content-Type"] = "application/json",
|
||||
@@ -298,9 +301,10 @@ M.parse_curl_args = function(provider, code_opts)
|
||||
body_opts.temperature = 1
|
||||
end
|
||||
|
||||
local tools = {}
|
||||
if code_opts.tools then
|
||||
for _, tool in ipairs(code_opts.tools) do
|
||||
local tools = nil
|
||||
if not disable_tools and prompt_opts.tools then
|
||||
tools = {}
|
||||
for _, tool in ipairs(prompt_opts.tools) do
|
||||
table.insert(tools, M.transform_tool(tool))
|
||||
end
|
||||
end
|
||||
@@ -315,7 +319,7 @@ M.parse_curl_args = function(provider, code_opts)
|
||||
headers = headers,
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = base.model,
|
||||
messages = M.parse_messages(code_opts),
|
||||
messages = M.parse_messages(prompt_opts),
|
||||
stream = stream,
|
||||
tools = tools,
|
||||
}, body_opts),
|
||||
|
||||
@@ -31,7 +31,7 @@ M.parse_api_key = function()
|
||||
return direct_output
|
||||
end
|
||||
|
||||
M.parse_curl_args = function(provider, code_opts)
|
||||
M.parse_curl_args = function(provider, prompt_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
local location = vim.fn.getenv("LOCATION") or "default-location"
|
||||
local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id"
|
||||
@@ -58,7 +58,7 @@ M.parse_curl_args = function(provider, code_opts)
|
||||
},
|
||||
proxy = base.proxy,
|
||||
insecure = base.allow_insecure,
|
||||
body = vim.tbl_deep_extend("force", {}, M.parse_messages(code_opts), body_opts),
|
||||
body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts),
|
||||
}
|
||||
end
|
||||
|
||||
|
||||
Reference in New Issue
Block a user