From 87c4c6b4937d1884960759aba4a0e42645688f2f Mon Sep 17 00:00:00 2001
From: kernitus <2789734+kernitus@users.noreply.github.com>
Date: Tue, 20 May 2025 15:06:25 +0100
Subject: [PATCH] fix: gemini & vertex tools (#2053)

* fix: tool calling with gemini and vertex

* chore: remove unnecessary comments

* chore: fix doc type errors

* feat: do not manually pass tool use list
---
 lua/avante/providers/gemini.lua | 144 +++++++++++++++++++++++---------
 lua/avante/providers/vertex.lua |  15 ++--
 2 files changed, 110 insertions(+), 49 deletions(-)

diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua
index 1567d4e..0450252 100644
--- a/lua/avante/providers/gemini.lua
+++ b/lua/avante/providers/gemini.lua
@@ -85,11 +85,18 @@ function M:parse_messages(opts)
         role = "function"
         local ok, content = pcall(vim.json.decode, item.content)
         if not ok then content = item.content end
+        -- item.name here refers to the name of the tool that was called,
+        -- which is available in the tool_result content item prepared by llm.lua
+        local tool_name = item.name
+        if not tool_name then
+          -- Fallback, though item.name should ideally always be present for tool_result
+          tool_name = tool_id_to_name[item.tool_use_id]
+        end
         table.insert(parts, {
           functionResponse = {
-            name = tool_id_to_name[item.tool_use_id],
+            name = tool_name,
             response = {
-              name = tool_id_to_name[item.tool_use_id],
+              name = tool_name, -- Gemini API requires the name in the response object as well
               content = content,
             },
           },
@@ -130,14 +137,68 @@
   }
 end
 
+--- Prepares the main request body for Gemini-like APIs.
+---@param provider_instance AvanteProviderFunctor The provider instance (self).
+---@param prompt_opts AvantePromptOptions Prompt options including messages, tools, system_prompt.
+---@param provider_conf table Provider configuration from config.lua (e.g., model, top-level temperature/max_tokens).
+---@param request_body table Request-specific overrides, typically from provider_conf.request_config_overrides.
+---@return table The fully constructed request body.
+function M.prepare_request_body(provider_instance, prompt_opts, provider_conf, request_body)
+  request_body = vim.tbl_deep_extend("force", request_body, {
+    generationConfig = {
+      temperature = request_body.temperature,
+      maxOutputTokens = request_body.max_tokens,
+    },
+  })
+  request_body.temperature = nil
+  request_body.max_tokens = nil
+
+  local disable_tools = provider_conf.disable_tools or false
+
+  if not disable_tools and prompt_opts.tools then
+    local function_declarations = {}
+    for _, tool in ipairs(prompt_opts.tools) do
+      table.insert(function_declarations, provider_instance:transform_to_function_declaration(tool))
+    end
+
+    if #function_declarations > 0 then
+      request_body.tools = {
+        {
+          functionDeclarations = function_declarations,
+        },
+      }
+    end
+  end
+
+  return vim.tbl_deep_extend("force", {}, provider_instance:parse_messages(prompt_opts), request_body)
+end
+
 function M:parse_response(ctx, data_stream, _, opts)
   local ok, json = pcall(vim.json.decode, data_stream)
-  if not ok then opts.on_stop({ reason = "error", error = json }) end
+  if not ok then
+    opts.on_stop({ reason = "error", error = "Failed to parse JSON response: " .. tostring(json) })
+    return
+  end
+
+  -- Handle prompt feedback first, as it might indicate an overall issue with the prompt
+  if json.promptFeedback and json.promptFeedback.blockReason then
+    local feedback = json.promptFeedback
+    OpenAI:finish_pending_messages(ctx, opts) -- Ensure any pending messages are cleared
+    opts.on_stop({
+      reason = "error",
+      error = "Prompt blocked or filtered. Reason: " .. feedback.blockReason,
+      details = feedback,
+    })
+    return
+  end
+
   if json.candidates and #json.candidates > 0 then
     local candidate = json.candidates[1]
     ---@type AvanteLLMToolUse[]
     local tool_use_list = {}
-    if candidate.content.parts ~= nil then
+
+    -- Check if candidate.content and candidate.content.parts exist before iterating
+    if candidate.content and candidate.content.parts then
       for _, part in ipairs(candidate.content.parts) do
         if part.text then
           if opts.on_chunk then opts.on_chunk(part.text) end
@@ -155,51 +216,56 @@ function M:parse_response(ctx, data_stream, _, opts)
         end
       end
     end
-    if candidate.finishReason and candidate.finishReason == "STOP" then
+
+    -- Check for finishReason to determine if this candidate's stream is done.
+    if candidate.finishReason then
       OpenAI:finish_pending_messages(ctx, opts)
-      if #tool_use_list > 0 then
-        opts.on_stop({ reason = "tool_use", tool_use_list = tool_use_list })
-      else
-        opts.on_stop({ reason = "complete" })
+      local reason_str = candidate.finishReason
+      local stop_details = { finish_reason = reason_str }
+
+      if reason_str == "TOOL_CODE" then
+        -- Model indicates a tool-related stop.
+        -- The tool_use list is added to the table in llm.lua
+        opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
+      elseif reason_str == "STOP" then
+        if #tool_use_list > 0 then
+          -- Natural stop, but tools were found in this final chunk.
+          opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
+        else
+          -- Natural stop, no tools in this final chunk.
+          -- llm.lua will check its accumulated tools if tool_choice was active.
+          opts.on_stop(vim.tbl_deep_extend("force", { reason = "complete" }, stop_details))
+        end
+      elseif reason_str == "MAX_TOKENS" then
+        opts.on_stop(vim.tbl_deep_extend("force", { reason = "max_tokens" }, stop_details))
+      elseif reason_str == "SAFETY" or reason_str == "RECITATION" then
+        opts.on_stop(
+          vim.tbl_deep_extend(
+            "force",
+            { reason = "error", error = "Generation stopped: " .. reason_str },
+            stop_details
+          )
+        )
+      else -- OTHER, FINISH_REASON_UNSPECIFIED, or any other unhandled reason.
+        opts.on_stop(
+          vim.tbl_deep_extend(
+            "force",
+            { reason = "error", error = "Generation stopped with unhandled reason: " .. reason_str },
+            stop_details
+          )
+        )
       end
     end
-  else
-    OpenAI:finish_pending_messages(ctx, opts)
-    opts.on_stop({ reason = "complete" })
+    -- If no finishReason, it's an intermediate chunk; do not call on_stop.
   end
 end
 
 function M:parse_curl_args(prompt_opts)
   local provider_conf, request_body = Providers.parse_config(self)
-  local disable_tools = provider_conf.disable_tools or false
-  request_body = vim.tbl_deep_extend("force", request_body, {
-    generationConfig = {
-      temperature = request_body.temperature,
-      maxOutputTokens = request_body.max_tokens,
-    },
-  })
-  request_body.temperature = nil
-  request_body.max_tokens = nil
-
-  local api_key = self.parse_api_key()
+  local api_key = self:parse_api_key()
   if api_key == nil then error("Cannot get the gemini api key!") end
 
-  local function_declarations = {}
-  if not disable_tools and prompt_opts.tools then
-    for _, tool in ipairs(prompt_opts.tools) do
-      table.insert(function_declarations, self:transform_to_function_declaration(tool))
-    end
-  end
-
-  if #function_declarations > 0 then
-    request_body.tools = {
-      {
-        functionDeclarations = function_declarations,
-      },
-    }
-  end
-
   return {
     url = Utils.url_join(
       provider_conf.endpoint,
@@ -208,7 +274,7 @@
     proxy = provider_conf.proxy,
     insecure = provider_conf.allow_insecure,
     headers = { ["Content-Type"] = "application/json" },
-    body = vim.tbl_deep_extend("force", {}, self:parse_messages(prompt_opts), request_body),
+    body = M.prepare_request_body(self, prompt_opts, provider_conf, request_body),
   }
 end
 
diff --git a/lua/avante/providers/vertex.lua b/lua/avante/providers/vertex.lua
index ccd93af..5119553 100644
--- a/lua/avante/providers/vertex.lua
+++ b/lua/avante/providers/vertex.lua
@@ -14,6 +14,7 @@ M.role_map = {
 M.is_disable_stream = Gemini.is_disable_stream
 M.parse_messages = Gemini.parse_messages
 M.parse_response = Gemini.parse_response
+M.transform_to_function_declaration = Gemini.transform_to_function_declaration
 
 local function execute_command(command)
   local handle = io.popen(command)
@@ -34,23 +35,17 @@
 end
 
 function M:parse_curl_args(prompt_opts)
   local provider_conf, request_body = P.parse_config(self)
+
   local location = vim.fn.getenv("LOCATION")
   local project_id = vim.fn.getenv("PROJECT_ID")
   local model_id = provider_conf.model or "default-model-id"
+
   if location == nil or location == vim.NIL then location = "default-location" end
   if project_id == nil or project_id == vim.NIL then project_id = "default-project-id" end
-  local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id)
+
+  local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id)
   url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id)
-  request_body = vim.tbl_deep_extend("force", request_body, {
-    generationConfig = {
-      temperature = request_body.temperature,
-      maxOutputTokens = request_body.max_tokens,
-    },
-  })
-  request_body.temperature = nil
-  request_body.max_tokens = nil
 
   local bearer_token = M.parse_api_key()
 
   return {
@@ -61,7 +56,7 @@
     },
     proxy = provider_conf.proxy,
     insecure = provider_conf.allow_insecure,
-    body = vim.tbl_deep_extend("force", {}, self:parse_messages(prompt_opts), request_body),
+    body = Gemini.prepare_request_body(self, prompt_opts, provider_conf, request_body),
   }
 end
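
Note (not part of the patch): the snippet below is a minimal sketch of how another Gemini-compatible provider could reuse the shared Gemini.prepare_request_body helper introduced above, following the same pattern as the vertex.lua change. The module layout, endpoint shape, and config fields shown here are illustrative assumptions, not avante API guarantees.

local P = require("avante.providers")
local Gemini = require("avante.providers.gemini")

local M = {}

-- Reuse the Gemini message and response handling, as vertex.lua does.
M.is_disable_stream = Gemini.is_disable_stream
M.parse_messages = Gemini.parse_messages
M.parse_response = Gemini.parse_response
M.transform_to_function_declaration = Gemini.transform_to_function_declaration

function M:parse_curl_args(prompt_opts)
  local provider_conf, request_body = P.parse_config(self)
  -- Hypothetical endpoint construction; a real provider would build its own
  -- URL and auth headers here (omitted for brevity).
  local url = string.format("%s/%s:streamGenerateContent?alt=sse", provider_conf.endpoint, provider_conf.model)
  return {
    url = url,
    headers = { ["Content-Type"] = "application/json" },
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
    -- generationConfig, the tools/functionDeclarations list, and the message
    -- contents are all assembled by the shared helper added in this patch.
    body = Gemini.prepare_request_body(self, prompt_opts, provider_conf, request_body),
  }
end

return M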