avante.nvim/lua/avante/providers/gemini.lua

local Utils = require("avante.utils")
local Providers = require("avante.providers")
local Clipboard = require("avante.clipboard")
local OpenAI = require("avante.providers").openai
---@class AvanteProviderFunctor
local M = {}
M.api_key_name = "GEMINI_API_KEY"
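
-- Gemini's chat API names the assistant role "model".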
M.role_map = {
  user = "user",
  assistant = "model",
}
function M:is_disable_stream() return false end
---@param tool AvanteLLMTool
function M:transform_to_function_declaration(tool)
  local input_schema_properties, required = Utils.llm_tool_param_fields_to_json_schema(tool.param.fields)
  local parameters = nil
  if not vim.tbl_isempty(input_schema_properties) then
    parameters = {
      type = "object",
      properties = input_schema_properties,
      required = required,
    }
  end
  return {
    name = tool.name,
    description = tool.get_description and tool.get_description() or tool.description,
    parameters = parameters,
  }
end
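
-- Convert Avante's generic message list into a Gemini generateContent
-- payload: a systemInstruction block plus a list of { role, parts } contents.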
function M:parse_messages(opts)
  local contents = {}
  local prev_role = nil
  local tool_id_to_name = {}
  vim.iter(opts.messages):each(function(message)
    local role = message.role
    -- Keep roles alternating: if two consecutive messages share a role,
    -- insert a short filler turn from the other side.
    if role == prev_role then
      if role == M.role_map["user"] then
        table.insert(contents, {
          role = M.role_map["assistant"],
          parts = {
            { text = "Ok, I understand." },
          },
        })
      else
        table.insert(contents, {
          role = M.role_map["user"],
          parts = {
            { text = "Ok" },
          },
        })
      end
    end
    prev_role = role
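    -- Translate each content item into a Gemini "part": plain text,
    -- inline image data, a functionCall, or a functionResponse.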
    local parts = {}
    local content_items = message.content
    if type(content_items) == "string" then
      table.insert(parts, { text = content_items })
    elseif type(content_items) == "table" then
      ---@cast content_items AvanteLLMMessageContentItem[]
      for _, item in ipairs(content_items) do
        if type(item) == "string" then
          table.insert(parts, { text = item })
        elseif type(item) == "table" and item.type == "text" then
          table.insert(parts, { text = item.text })
        elseif type(item) == "table" and item.type == "image" then
          table.insert(parts, {
            inline_data = {
              mime_type = "image/png",
              data = item.source.data,
            },
          })
        elseif type(item) == "table" and item.type == "tool_use" then
          tool_id_to_name[item.id] = item.name
          role = "model"
          table.insert(parts, {
            functionCall = {
              name = item.name,
              args = item.input,
            },
          })
        elseif type(item) == "table" and item.type == "tool_result" then
          role = "function"
          local ok, content = pcall(vim.json.decode, item.content)
          if not ok then content = item.content end
          table.insert(parts, {
            functionResponse = {
              name = tool_id_to_name[item.tool_use_id],
              response = {
                name = tool_id_to_name[item.tool_use_id],
                content = content,
              },
            },
          })
        elseif type(item) == "table" and item.type == "thinking" then
          table.insert(parts, { text = item.thinking })
        elseif type(item) == "table" and item.type == "redacted_thinking" then
          table.insert(parts, { text = item.data })
        end
      end
    end
    table.insert(contents, { role = M.role_map[role] or role, parts = parts })
  end)
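
  -- Attach any pasted images to the last message as inline base64 PNG data.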
  if Clipboard.support_paste_image() and opts.image_paths then
    for _, image_path in ipairs(opts.image_paths) do
      local image_data = {
        inline_data = {
          mime_type = "image/png",
          data = Clipboard.get_base64_content(image_path),
        },
      }
      table.insert(contents[#contents].parts, image_data)
    end
  end
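
  -- The system prompt travels separately as systemInstruction.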
  return {
    systemInstruction = {
      role = "user",
      parts = {
        {
          text = opts.system_prompt,
        },
      },
    },
    contents = contents,
  }
end
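
-- Handle one SSE event from streamGenerateContent: stream text chunks to the
-- UI and collect functionCall parts as tool-use requests.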
function M:parse_response(ctx, data_stream, _, opts)
  local ok, json = pcall(vim.json.decode, data_stream)
  if not ok then
    opts.on_stop({ reason = "error", error = json })
    return
  end
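  -- Only text and functionCall parts are handled here; other part types are ignored.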
  if json.candidates and #json.candidates > 0 then
    local candidate = json.candidates[1]
    ---@type AvanteLLMToolUse[]
    local tool_use_list = {}
    if candidate.content.parts ~= nil then
      for _, part in ipairs(candidate.content.parts) do
        if part.text then
          if opts.on_chunk then opts.on_chunk(part.text) end
          OpenAI:add_text_message(ctx, part.text, "generating", opts)
        elseif part.functionCall then
          if not ctx.function_call_id then ctx.function_call_id = 0 end
          ctx.function_call_id = ctx.function_call_id + 1
          local tool_use = {
            id = ctx.session_id .. "-" .. tostring(ctx.function_call_id),
            name = part.functionCall.name,
            input_json = vim.json.encode(part.functionCall.args),
          }
          table.insert(tool_use_list, tool_use)
          OpenAI:add_tool_use_message(tool_use, "generated", opts)
        end
      end
    end
    if candidate.finishReason and candidate.finishReason == "STOP" then
      OpenAI:finish_pending_messages(ctx, opts)
      if #tool_use_list > 0 then
        opts.on_stop({ reason = "tool_use", tool_use_list = tool_use_list })
      else
        opts.on_stop({ reason = "complete" })
      end
    end
  else
    OpenAI:finish_pending_messages(ctx, opts)
    opts.on_stop({ reason = "complete" })
  end
end
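
-- Build the curl request for Gemini's streaming generateContent endpoint.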
function M:parse_curl_args(prompt_opts)
  local provider_conf, request_body = Providers.parse_config(self)
  local disable_tools = provider_conf.disable_tools or false
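
  -- Gemini expects sampling options under generationConfig rather than at the
  -- top level, so move temperature and max_tokens there.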
  request_body = vim.tbl_deep_extend("force", request_body, {
    generationConfig = {
      temperature = request_body.temperature,
      maxOutputTokens = request_body.max_tokens,
    },
  })
  request_body.temperature = nil
  request_body.max_tokens = nil
  local api_key = self.parse_api_key()
  if api_key == nil then error("Cannot get the Gemini API key!") end
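
  -- Advertise the prompt's tools as Gemini functionDeclarations unless tools
  -- are disabled for this provider.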
  local function_declarations = {}
  if not disable_tools and prompt_opts.tools then
    for _, tool in ipairs(prompt_opts.tools) do
      table.insert(function_declarations, self:transform_to_function_declaration(tool))
    end
  end
  if #function_declarations > 0 then
    request_body.tools = {
      {
        functionDeclarations = function_declarations,
      },
    }
  end
  return {
    url = Utils.url_join(
      provider_conf.endpoint,
      provider_conf.model .. ":streamGenerateContent?alt=sse&key=" .. api_key
    ),
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
    headers = { ["Content-Type"] = "application/json" },
    body = vim.tbl_deep_extend("force", {}, self:parse_messages(prompt_opts), request_body),
  }
end
return M