fix: gemini & vertex tools (#2053)
* fix: tool calling with gemini and vertex * chore: remove unnecessary comments * chore: fix doc type errors * feat: do not manually pass tool use list
This commit is contained in:
@@ -85,11 +85,18 @@ function M:parse_messages(opts)
|
|||||||
role = "function"
|
role = "function"
|
||||||
local ok, content = pcall(vim.json.decode, item.content)
|
local ok, content = pcall(vim.json.decode, item.content)
|
||||||
if not ok then content = item.content end
|
if not ok then content = item.content end
|
||||||
|
-- item.name here refers to the name of the tool that was called,
|
||||||
|
-- which is available in the tool_result content item prepared by llm.lua
|
||||||
|
local tool_name = item.name
|
||||||
|
if not tool_name then
|
||||||
|
-- Fallback, though item.name should ideally always be present for tool_result
|
||||||
|
tool_name = tool_id_to_name[item.tool_use_id]
|
||||||
|
end
|
||||||
table.insert(parts, {
|
table.insert(parts, {
|
||||||
functionResponse = {
|
functionResponse = {
|
||||||
name = tool_id_to_name[item.tool_use_id],
|
name = tool_name,
|
||||||
response = {
|
response = {
|
||||||
name = tool_id_to_name[item.tool_use_id],
|
name = tool_name, -- Gemini API requires the name in the response object as well
|
||||||
content = content,
|
content = content,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -130,14 +137,68 @@ function M:parse_messages(opts)
|
|||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Prepares the main request body for Gemini-like APIs.
|
||||||
|
---@param provider_instance AvanteProviderFunctor The provider instance (self).
|
||||||
|
---@param prompt_opts AvantePromptOptions Prompt options including messages, tools, system_prompt.
|
||||||
|
---@param provider_conf table Provider configuration from config.lua (e.g., model, top-level temperature/max_tokens).
|
||||||
|
---@param request_body table Request-specific overrides, typically from provider_conf.request_config_overrides.
|
||||||
|
---@return table The fully constructed request body.
|
||||||
|
function M.prepare_request_body(provider_instance, prompt_opts, provider_conf, request_body)
|
||||||
|
request_body = vim.tbl_deep_extend("force", request_body, {
|
||||||
|
generationConfig = {
|
||||||
|
temperature = request_body.temperature,
|
||||||
|
maxOutputTokens = request_body.max_tokens,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
request_body.temperature = nil
|
||||||
|
request_body.max_tokens = nil
|
||||||
|
|
||||||
|
local disable_tools = provider_conf.disable_tools or false
|
||||||
|
|
||||||
|
if not disable_tools and prompt_opts.tools then
|
||||||
|
local function_declarations = {}
|
||||||
|
for _, tool in ipairs(prompt_opts.tools) do
|
||||||
|
table.insert(function_declarations, provider_instance:transform_to_function_declaration(tool))
|
||||||
|
end
|
||||||
|
|
||||||
|
if #function_declarations > 0 then
|
||||||
|
request_body.tools = {
|
||||||
|
{
|
||||||
|
functionDeclarations = function_declarations,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return vim.tbl_deep_extend("force", {}, provider_instance:parse_messages(prompt_opts), request_body)
|
||||||
|
end
|
||||||
|
|
||||||
function M:parse_response(ctx, data_stream, _, opts)
|
function M:parse_response(ctx, data_stream, _, opts)
|
||||||
local ok, json = pcall(vim.json.decode, data_stream)
|
local ok, json = pcall(vim.json.decode, data_stream)
|
||||||
if not ok then opts.on_stop({ reason = "error", error = json }) end
|
if not ok then
|
||||||
|
opts.on_stop({ reason = "error", error = "Failed to parse JSON response: " .. tostring(json) })
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Handle prompt feedback first, as it might indicate an overall issue with the prompt
|
||||||
|
if json.promptFeedback and json.promptFeedback.blockReason then
|
||||||
|
local feedback = json.promptFeedback
|
||||||
|
OpenAI:finish_pending_messages(ctx, opts) -- Ensure any pending messages are cleared
|
||||||
|
opts.on_stop({
|
||||||
|
reason = "error",
|
||||||
|
error = "Prompt blocked or filtered. Reason: " .. feedback.blockReason,
|
||||||
|
details = feedback,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
if json.candidates and #json.candidates > 0 then
|
if json.candidates and #json.candidates > 0 then
|
||||||
local candidate = json.candidates[1]
|
local candidate = json.candidates[1]
|
||||||
---@type AvanteLLMToolUse[]
|
---@type AvanteLLMToolUse[]
|
||||||
local tool_use_list = {}
|
local tool_use_list = {}
|
||||||
if candidate.content.parts ~= nil then
|
|
||||||
|
-- Check if candidate.content and candidate.content.parts exist before iterating
|
||||||
|
if candidate.content and candidate.content.parts then
|
||||||
for _, part in ipairs(candidate.content.parts) do
|
for _, part in ipairs(candidate.content.parts) do
|
||||||
if part.text then
|
if part.text then
|
||||||
if opts.on_chunk then opts.on_chunk(part.text) end
|
if opts.on_chunk then opts.on_chunk(part.text) end
|
||||||
@@ -155,51 +216,56 @@ function M:parse_response(ctx, data_stream, _, opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
if candidate.finishReason and candidate.finishReason == "STOP" then
|
|
||||||
|
-- Check for finishReason to determine if this candidate's stream is done.
|
||||||
|
if candidate.finishReason then
|
||||||
OpenAI:finish_pending_messages(ctx, opts)
|
OpenAI:finish_pending_messages(ctx, opts)
|
||||||
if #tool_use_list > 0 then
|
local reason_str = candidate.finishReason
|
||||||
opts.on_stop({ reason = "tool_use", tool_use_list = tool_use_list })
|
local stop_details = { finish_reason = reason_str }
|
||||||
else
|
|
||||||
opts.on_stop({ reason = "complete" })
|
if reason_str == "TOOL_CODE" then
|
||||||
|
-- Model indicates a tool-related stop.
|
||||||
|
-- The tool_use list is added to the table in llm.lua
|
||||||
|
opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
|
||||||
|
elseif reason_str == "STOP" then
|
||||||
|
if #tool_use_list > 0 then
|
||||||
|
-- Natural stop, but tools were found in this final chunk.
|
||||||
|
opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
|
||||||
|
else
|
||||||
|
-- Natural stop, no tools in this final chunk.
|
||||||
|
-- llm.lua will check its accumulated tools if tool_choice was active.
|
||||||
|
opts.on_stop(vim.tbl_deep_extend("force", { reason = "complete" }, stop_details))
|
||||||
|
end
|
||||||
|
elseif reason_str == "MAX_TOKENS" then
|
||||||
|
opts.on_stop(vim.tbl_deep_extend("force", { reason = "max_tokens" }, stop_details))
|
||||||
|
elseif reason_str == "SAFETY" or reason_str == "RECITATION" then
|
||||||
|
opts.on_stop(
|
||||||
|
vim.tbl_deep_extend(
|
||||||
|
"force",
|
||||||
|
{ reason = "error", error = "Generation stopped: " .. reason_str },
|
||||||
|
stop_details
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else -- OTHER, FINISH_REASON_UNSPECIFIED, or any other unhandled reason.
|
||||||
|
opts.on_stop(
|
||||||
|
vim.tbl_deep_extend(
|
||||||
|
"force",
|
||||||
|
{ reason = "error", error = "Generation stopped with unhandled reason: " .. reason_str },
|
||||||
|
stop_details
|
||||||
|
)
|
||||||
|
)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
else
|
-- If no finishReason, it's an intermediate chunk; do not call on_stop.
|
||||||
OpenAI:finish_pending_messages(ctx, opts)
|
|
||||||
opts.on_stop({ reason = "complete" })
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function M:parse_curl_args(prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = Providers.parse_config(self)
|
local provider_conf, request_body = Providers.parse_config(self)
|
||||||
local disable_tools = provider_conf.disable_tools or false
|
|
||||||
|
|
||||||
request_body = vim.tbl_deep_extend("force", request_body, {
|
local api_key = self:parse_api_key()
|
||||||
generationConfig = {
|
|
||||||
temperature = request_body.temperature,
|
|
||||||
maxOutputTokens = request_body.max_tokens,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
request_body.temperature = nil
|
|
||||||
request_body.max_tokens = nil
|
|
||||||
|
|
||||||
local api_key = self.parse_api_key()
|
|
||||||
if api_key == nil then error("Cannot get the gemini api key!") end
|
if api_key == nil then error("Cannot get the gemini api key!") end
|
||||||
|
|
||||||
local function_declarations = {}
|
|
||||||
if not disable_tools and prompt_opts.tools then
|
|
||||||
for _, tool in ipairs(prompt_opts.tools) do
|
|
||||||
table.insert(function_declarations, self:transform_to_function_declaration(tool))
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
if #function_declarations > 0 then
|
|
||||||
request_body.tools = {
|
|
||||||
{
|
|
||||||
functionDeclarations = function_declarations,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
end
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
url = Utils.url_join(
|
url = Utils.url_join(
|
||||||
provider_conf.endpoint,
|
provider_conf.endpoint,
|
||||||
@@ -208,7 +274,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
proxy = provider_conf.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = { ["Content-Type"] = "application/json" },
|
headers = { ["Content-Type"] = "application/json" },
|
||||||
body = vim.tbl_deep_extend("force", {}, self:parse_messages(prompt_opts), request_body),
|
body = M.prepare_request_body(self, prompt_opts, provider_conf, request_body),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ M.role_map = {
|
|||||||
M.is_disable_stream = Gemini.is_disable_stream
|
M.is_disable_stream = Gemini.is_disable_stream
|
||||||
M.parse_messages = Gemini.parse_messages
|
M.parse_messages = Gemini.parse_messages
|
||||||
M.parse_response = Gemini.parse_response
|
M.parse_response = Gemini.parse_response
|
||||||
|
M.transform_to_function_declaration = Gemini.transform_to_function_declaration
|
||||||
|
|
||||||
local function execute_command(command)
|
local function execute_command(command)
|
||||||
local handle = io.popen(command)
|
local handle = io.popen(command)
|
||||||
@@ -34,23 +35,17 @@ end
|
|||||||
|
|
||||||
function M:parse_curl_args(prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(self)
|
local provider_conf, request_body = P.parse_config(self)
|
||||||
|
|
||||||
local location = vim.fn.getenv("LOCATION")
|
local location = vim.fn.getenv("LOCATION")
|
||||||
local project_id = vim.fn.getenv("PROJECT_ID")
|
local project_id = vim.fn.getenv("PROJECT_ID")
|
||||||
local model_id = provider_conf.model or "default-model-id"
|
local model_id = provider_conf.model or "default-model-id"
|
||||||
|
|
||||||
if location == nil or location == vim.NIL then location = "default-location" end
|
if location == nil or location == vim.NIL then location = "default-location" end
|
||||||
if project_id == nil or project_id == vim.NIL then project_id = "default-project-id" end
|
if project_id == nil or project_id == vim.NIL then project_id = "default-project-id" end
|
||||||
local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id)
|
|
||||||
|
|
||||||
|
local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id)
|
||||||
url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id)
|
url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id)
|
||||||
|
|
||||||
request_body = vim.tbl_deep_extend("force", request_body, {
|
|
||||||
generationConfig = {
|
|
||||||
temperature = request_body.temperature,
|
|
||||||
maxOutputTokens = request_body.max_tokens,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
request_body.temperature = nil
|
|
||||||
request_body.max_tokens = nil
|
|
||||||
local bearer_token = M.parse_api_key()
|
local bearer_token = M.parse_api_key()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -61,7 +56,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
},
|
},
|
||||||
proxy = provider_conf.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
body = vim.tbl_deep_extend("force", {}, self:parse_messages(prompt_opts), request_body),
|
body = Gemini.prepare_request_body(self, prompt_opts, provider_conf, request_body),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user