adding claude.lua

This commit is contained in:
2026-01-13 21:02:45 -05:00
parent 6b71c76517
commit 6268a57498
5 changed files with 2294 additions and 287 deletions

View File

@@ -361,4 +361,148 @@ function M.format_messages_for_claude(messages)
return formatted
end
--- Generate with tool use support for agentic mode
---@param messages table[] Conversation history
---@param context table Context information
---@param tool_definitions table Tool definitions (currently unused; tools come from codetyper.agent.tools)
---@param callback fun(response: table|nil, error: string|nil) Callback with raw response
function M.generate_with_tools(messages, context, tool_definitions, callback)
  local api_key = get_api_key()
  if not api_key then
    callback(nil, "Claude API key not configured")
    return
  end
  local tools_module = require("codetyper.agent.tools")
  local agent_prompts = require("codetyper.prompts.agent")
  -- Build system prompt with agent instructions appended
  local system_prompt = llm.build_system_prompt(context)
  system_prompt = system_prompt .. "\n\n" .. agent_prompts.system
  system_prompt = system_prompt .. "\n\n" .. agent_prompts.tool_instructions
  -- Build request body with tools
  local body = {
    model = get_model(),
    max_tokens = 4096,
    system = system_prompt,
    messages = M.format_messages_for_claude(messages),
    tools = tools_module.to_claude_format(),
  }
  local json_body = vim.json.encode(body)
  local cmd = {
    "curl",
    "-s",
    "-X", "POST",
    API_URL,
    "-H", "Content-Type: application/json",
    "-H", "x-api-key: " .. api_key,
    "-H", "anthropic-version: 2023-06-01",
    "-d", json_body,
  }
  -- Invoke the callback at most once: on a curl failure both on_stderr
  -- and on_exit can fire, which previously called `callback` twice for
  -- a single request.
  local finished = false
  local function finish(response, err)
    if finished then
      return
    end
    finished = true
    vim.schedule(function()
      callback(response, err)
    end)
  end
  vim.fn.jobstart(cmd, {
    stdout_buffered = true,
    on_stdout = function(_, data)
      -- jobstart delivers {""} when there is no output; ignore it.
      if not data or #data == 0 or (data[1] == "" and #data == 1) then
        return
      end
      local response_text = table.concat(data, "\n")
      local ok, response = pcall(vim.json.decode, response_text)
      if not ok then
        finish(nil, "Failed to parse Claude response")
        return
      end
      if response.error then
        finish(nil, response.error.message or "Claude API error")
        return
      end
      -- Return raw response for parser to handle
      finish(response, nil)
    end,
    on_stderr = function(_, data)
      if data and #data > 0 and data[1] ~= "" then
        finish(nil, "Claude API request failed: " .. table.concat(data, "\n"))
      end
    end,
    on_exit = function(_, code)
      if code ~= 0 then
        finish(nil, "Claude API request failed with code: " .. code)
      end
    end,
  })
end
--- Format messages for Claude API
---@param messages table[] Internal message format
---@return table[] Claude API message format
function M.format_messages_for_claude(messages)
  local formatted = {}
  for _, msg in ipairs(messages) do
    if msg.role == "user" then
      -- String content is plain text; table content carries tool results.
      -- Claude accepts both shapes directly, so forward content as-is.
      -- (The original had two byte-identical if/else branches here.)
      table.insert(formatted, {
        role = "user",
        content = msg.content,
      })
    elseif msg.role == "assistant" then
      -- Build content array for assistant messages
      local content = {}
      -- Add text if present
      if msg.content and msg.content ~= "" then
        table.insert(content, {
          type = "text",
          text = msg.content,
        })
      end
      -- Add tool uses if present
      if msg.tool_calls then
        for _, tool_call in ipairs(msg.tool_calls) do
          table.insert(content, {
            type = "tool_use",
            id = tool_call.id,
            name = tool_call.name,
            input = tool_call.parameters,
          })
        end
      end
      -- Skip assistant messages that end up with no content blocks.
      if #content > 0 then
        table.insert(formatted, {
          role = "assistant",
          content = content,
        })
      end
    end
  end
  return formatted
end
return M

View File

@@ -0,0 +1,531 @@
---Reference implementation:
---https://github.com/zbirenbaum/copilot.lua/blob/master/lua/copilot/auth.lua config file
---https://github.com/zed-industries/zed/blob/ad43bbbf5eda59eba65309735472e0be58b4f7dd/crates/copilot/src/copilot_chat.rs#L272 for authorization
---
---@class CopilotToken
---@field annotations_enabled boolean
---@field chat_enabled boolean
---@field chat_jetbrains_enabled boolean
---@field code_quote_enabled boolean
---@field codesearch boolean
---@field copilotignore_enabled boolean
---@field endpoints {api: string, ["origin-tracker"]: string, proxy: string, telemetry: string}
---@field expires_at integer
---@field individual boolean
---@field nes_enabled boolean
---@field prompt_8k boolean
---@field public_suggestions string
---@field refresh_in integer
---@field sku string
---@field snippy_load_test_enabled boolean
---@field telemetry string
---@field token string
---@field tracking_id string
---@field vsc_electron_fetcher boolean
---@field xcode boolean
---@field xcode_chat boolean
local curl = require("plenary.curl")
local Path = require("plenary.path")
local Utils = require("avante.utils")
local Providers = require("avante.providers")
local OpenAI = require("avante.providers").openai
-- Internal helpers (token acquisition/refresh); not exported.
local H = {}
---@class AvanteProviderFunctor
local M = {}
-- On-disk cache of the Copilot chat token, shared across Neovim instances.
local copilot_path = vim.fn.stdpath("data") .. "/avante/github-copilot.json"
-- PID lockfile electing which Neovim instance owns the refresh timer.
local lockfile_path = vim.fn.stdpath("data") .. "/avante/copilot-timer.lock"
-- Lockfile management
-- Check whether a process with the given PID is alive.
-- vim.uv.kill with signal 0 performs no action but reports whether the
-- target process exists (returns 0 on success).
---@param pid integer
---@return boolean
local function is_process_running(pid)
  -- `nil == 0` is false, so the original's extra nil-check was redundant.
  return vim.uv.kill(pid, 0) == 0
end
-- Attempt to become the single "manager" instance that owns the
-- background token-refresh timer, via a PID lockfile.
-- Returns true when this process now owns the lock.
local function try_acquire_timer_lock()
  local lockfile = Path:new(lockfile_path)
  -- Write our PID to a per-process temp file first so the final
  -- os.rename() into place is a single filesystem operation.
  local tmp_lockfile = lockfile_path .. ".tmp." .. vim.fn.getpid()
  Path:new(tmp_lockfile):write(tostring(vim.fn.getpid()), "w")
  -- Check existing lock
  if lockfile:exists() then
    local content = lockfile:read()
    local pid = tonumber(content)
    -- A live PID means another instance is managing; a dead PID means
    -- the lock is stale and we may take over.
    if pid and is_process_running(pid) then
      os.remove(tmp_lockfile)
      return false -- Another instance is already managing
    end
  end
  -- Attempt to take ownership
  -- NOTE(review): rename-over-existing succeeds on POSIX, so two
  -- instances racing past the staleness check could both "win" —
  -- confirm this race is acceptable here.
  local success = os.rename(tmp_lockfile, lockfile_path)
  if not success then
    os.remove(tmp_lockfile)
    return false
  end
  return true
end
-- Every 30 seconds, check whether this instance should take over
-- token-refresh duties (e.g. after the previous manager instance exits).
local function start_manager_check_timer()
  local old = M._manager_check_timer
  if old then
    old:stop()
    old:close()
  end
  local timer = vim.uv.new_timer()
  M._manager_check_timer = timer
  local on_tick = vim.schedule_wrap(function()
    -- Only attempt takeover when we are not already the manager.
    if not M._refresh_timer and try_acquire_timer_lock() then
      M.setup_timer()
    end
  end)
  timer:start(30000, 30000, on_tick)
end
---@class OAuthToken
---@field user string
---@field oauth_token string
---
--- Locate the GitHub OAuth token written by copilot.lua (hosts.json) or
--- copilot.vim (apps.json) and return it. Errors when neither file exists.
---@return string
function H.get_oauth_token()
  local xdg_config = vim.fn.expand("$XDG_CONFIG_HOME")
  local os_name = Utils.get_os_name()
  ---@type string
  local config_dir
  -- expand() returns the literal "$XDG_CONFIG_HOME" when unset; the
  -- isdirectory() check filters that case out.
  if xdg_config and vim.fn.isdirectory(xdg_config) > 0 then
    config_dir = xdg_config
  elseif vim.tbl_contains({ "linux", "darwin" }, os_name) then
    config_dir = vim.fn.expand("~/.config")
  else
    config_dir = vim.fn.expand("~/AppData/Local")
  end
  --- hosts.json (copilot.lua), apps.json (copilot.vim)
  ---@type Path[]
  local paths = vim.iter({ "hosts.json", "apps.json" }):fold({}, function(acc, path)
    local yason = Path:new(config_dir):joinpath("github-copilot", path)
    if yason:exists() then
      table.insert(acc, yason)
    end
    return acc
  end)
  if #paths == 0 then
    error("You must setup copilot with either copilot.lua or copilot.vim", 2)
  end
  -- Prefer hosts.json when both files are present.
  local yason = paths[1]
  -- The JSON maps host names to token records; pick the github.com entry.
  return vim
    .iter(
      ---@type table<string, OAuthToken>
      ---@diagnostic disable-next-line: param-type-mismatch
      vim.json.decode(yason:read())
    )
    :filter(function(k, _)
      return k:match("github.com")
    end)
    ---@param acc {oauth_token: string}
    :fold({}, function(acc, _, v)
      acc.oauth_token = v.oauth_token
      return acc
    end)
    .oauth_token
end
-- Endpoint that exchanges the OAuth token for a short-lived chat token.
H.chat_auth_url = "https://api.github.com/copilot_internal/v2/token"
--- Chat Completions endpoint under the token-provided API base URL.
function H.chat_completion_url(base_url)
  return Utils.url_join(base_url, "/chat/completions")
end
--- Response API endpoint under the token-provided API base URL.
function H.response_url(base_url)
  return Utils.url_join(base_url, "/responses")
end
--- Refresh the Copilot chat token using the stored OAuth token.
--- No-op when the cached token is still valid, unless `force` is set.
---@param async boolean? Refresh in the background (default true)
---@param force boolean? Refresh even if the cached token has not expired
---@return boolean? false when the refresh was skipped as unnecessary
function H.refresh_token(async, force)
  if not M.state then
    error("internal initialization error")
  end
  async = async == nil and true or async
  force = force or false
  -- Do not refresh token if not forced or not expired
  if
    not force
    and M.state.github_token
    and M.state.github_token.expires_at
    and M.state.github_token.expires_at > math.floor(os.time())
  then
    return false
  end
  local provider_conf = Providers.get_config("copilot")
  local curl_opts = {
    headers = {
      ["Authorization"] = "token " .. M.state.oauth_token,
      ["Accept"] = "application/json",
    },
    timeout = provider_conf.timeout,
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
  }
  local function handle_response(response)
    if response.status == 200 then
      M.state.github_token = vim.json.decode(response.body)
      -- Persist the token so sibling Neovim instances pick it up via the
      -- file watcher instead of refreshing themselves.
      local file = Path:new(copilot_path)
      file:write(vim.json.encode(M.state.github_token), "w")
      if not vim.g.avante_login then
        vim.g.avante_login = true
      end
      -- If triggered synchronously, reset the timer so the background
      -- refresh schedule re-anchors to the new expiry.
      if not async and M._refresh_timer then
        M.setup_timer()
      end
      return true
    else
      -- error() raises, so no return is needed here (the original had an
      -- unreachable `return false` after this call).
      error("Failed to get success response: " .. vim.inspect(response))
    end
  end
  if async then
    curl.get(
      H.chat_auth_url,
      vim.tbl_deep_extend("force", {
        callback = handle_response,
      }, curl_opts)
    )
  else
    local response = curl.get(H.chat_auth_url, curl_opts)
    handle_response(response)
  end
end
---@private
---@class AvanteCopilotState
---@field oauth_token string
---@field github_token CopilotToken?
M.state = nil
-- Copilot needs no user-supplied API key; auth flows from the OAuth token.
M.api_key_name = ""
M.tokenizer_id = "gpt-4o"
M.role_map = {
  user = "user",
  assistant = "assistant",
}
--- Streaming responses are supported.
function M:is_disable_stream()
  return false
end
-- Fall back to the OpenAI provider for any method not defined here.
setmetatable(M, { __index = OpenAI })
--- Fetch (and memoize) the list of chat-capable Copilot models.
---@return table[] model descriptors
function M:list_models()
  if M._model_list_cache then
    return M._model_list_cache
  end
  if not M._is_setup then
    M.setup()
  end
  -- refresh token synchronously, only if it has expired
  -- (this should rarely happen, as we refresh the token in the background)
  H.refresh_token(false, false)
  local provider_conf = Providers.parse_config(self)
  local headers = self:build_headers()
  local curl_opts = {
    headers = Utils.tbl_override(headers, self.extra_headers),
    timeout = provider_conf.timeout,
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
  }
  local function handle_response(response)
    if response.status == 200 then
      local body = vim.json.decode(response.body)
      -- ref: https://github.com/CopilotC-Nvim/CopilotChat.nvim/blob/16d897fd43d07e3b54478ccdb2f8a16e4df4f45a/lua/CopilotChat/config/providers.lua#L171-L187
      local models = vim.iter(body.data)
        :filter(function(model)
          -- Keep chat models; drop pay-as-you-go SKUs.
          return model.capabilities.type == "chat" and not vim.endswith(model.id, "paygo")
        end)
        :map(function(model)
          return {
            id = model.id,
            display_name = model.name,
            name = "copilot/" .. model.name .. " (" .. model.id .. ")",
            provider_name = "copilot",
            tokenizer = model.capabilities.tokenizer,
            max_input_tokens = model.capabilities.limits.max_prompt_tokens,
            max_output_tokens = model.capabilities.limits.max_output_tokens,
            -- Enabled unless an explicit policy marks it otherwise.
            policy = not model["policy"] or model["policy"]["state"] == "enabled",
            version = model.version,
          }
        end)
        :totable()
      M._model_list_cache = models
      return models
    else
      error("Failed to get success response: " .. vim.inspect(response))
      return {}
    end
  end
  local response = curl.get((M.state.github_token.endpoints.api or "") .. "/models", curl_opts)
  return handle_response(response)
end
--- HTTP headers required by the Copilot chat endpoints.
---@return table<string, string>
function M:build_headers()
  local headers = {}
  headers["Authorization"] = "Bearer " .. M.state.github_token.token
  headers["User-Agent"] = "GitHubCopilotChat/0.26.7"
  headers["Editor-Version"] = "vscode/1.105.1"
  headers["Editor-Plugin-Version"] = "copilot-chat/0.26.7"
  headers["Copilot-Integration-Id"] = "vscode-chat"
  headers["Openai-Intent"] = "conversation-edits"
  return headers
end
--- Build the curl request for a chat turn, supporting both the Chat
--- Completions API and the Response API.
---@param prompt_opts AvantePromptOptions
---@return table curl options (url, headers, body, ...)
function M:parse_curl_args(prompt_opts)
  -- refresh token synchronously, only if it has expired
  -- (this should rarely happen, as we refresh the token in the background)
  H.refresh_token(false, false)
  local provider_conf, request_body = Providers.parse_config(self)
  local use_response_api = Providers.resolve_use_response_api(provider_conf, prompt_opts)
  local disable_tools = provider_conf.disable_tools or false
  -- Apply OpenAI's set_allowed_params for Response API compatibility
  OpenAI.set_allowed_params(provider_conf, request_body)
  local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
  local tools = nil
  if not disable_tools and prompt_opts.tools and not use_ReAct_prompt then
    tools = {}
    for _, tool in ipairs(prompt_opts.tools) do
      local transformed_tool = OpenAI:transform_tool(tool)
      -- Response API uses flattened tool structure
      if use_response_api then
        if transformed_tool.type == "function" and transformed_tool["function"] then
          transformed_tool = {
            type = "function",
            name = transformed_tool["function"].name,
            description = transformed_tool["function"].description,
            parameters = transformed_tool["function"].parameters,
          }
        end
      end
      table.insert(tools, transformed_tool)
    end
  end
  local headers = self:build_headers()
  -- Copilot distinguishes user-initiated from agent-initiated requests.
  if prompt_opts.messages and #prompt_opts.messages > 0 then
    local last_message = prompt_opts.messages[#prompt_opts.messages]
    local initiator = last_message.role == "user" and "user" or "agent"
    headers["X-Initiator"] = initiator
  end
  local parsed_messages = self:parse_messages(prompt_opts)
  -- Build base body
  local base_body = {
    model = provider_conf.model,
    stream = true,
    tools = tools,
  }
  -- Response API uses 'input' instead of 'messages'
  -- NOTE: Copilot doesn't support previous_response_id, always send full history
  if use_response_api then
    base_body.input = parsed_messages
    -- Response API uses max_output_tokens instead of max_tokens/max_completion_tokens
    if request_body.max_completion_tokens then
      request_body.max_output_tokens = request_body.max_completion_tokens
      request_body.max_completion_tokens = nil
    end
    if request_body.max_tokens then
      request_body.max_output_tokens = request_body.max_tokens
      request_body.max_tokens = nil
    end
    -- Response API doesn't use stream_options
    base_body.stream_options = nil
    base_body.include = { "reasoning.encrypted_content" }
    base_body.reasoning = {
      summary = "detailed",
    }
    base_body.truncation = "disabled"
  else
    base_body.messages = parsed_messages
    base_body.stream_options = {
      include_usage = true,
    }
  end
  -- Prefer the API base URL delivered with the token; fall back to config.
  local base_url = M.state.github_token.endpoints.api or provider_conf.endpoint
  local build_url = use_response_api and H.response_url or H.chat_completion_url
  return {
    url = build_url(base_url),
    timeout = provider_conf.timeout,
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
    headers = Utils.tbl_override(headers, self.extra_headers),
    body = vim.tbl_deep_extend("force", base_body, request_body),
  }
end
-- Handle to the background token-refresh timer (nil until setup_timer runs).
M._refresh_timer = nil
--- (Re)start the background token-refresh timer.
--- The first tick fires two minutes before the current token expires;
--- subsequent ticks repeat every 28 minutes.
function M.setup_timer()
  local existing = M._refresh_timer
  if existing then
    existing:stop()
    existing:close()
  end
  local now = math.floor(os.time())
  local token = M.state.github_token
  local expires_at = token and token.expires_at or now
  local seconds_left = math.max(0, expires_at - now)
  -- First refresh two minutes before expiry (clamped to fire immediately
  -- when the token is already near or past expiry).
  local first_ms = math.max(0, (seconds_left - 120) * 1000)
  local repeat_ms = 28 * 60 * 1000
  local timer = vim.uv.new_timer()
  M._refresh_timer = timer
  timer:start(
    first_ms,
    repeat_ms,
    vim.schedule_wrap(function()
      H.refresh_token(true, true)
    end)
  )
end
--- Watch the on-disk token file and reload it whenever another Neovim
--- instance (the refresh manager) writes a new token.
function M.setup_file_watcher()
  if M._file_watcher then
    return
  end
  local token_file = Path:new(copilot_path)
  local watcher = vim.uv.new_fs_event()
  M._file_watcher = watcher
  local on_change = vim.schedule_wrap(function()
    if not token_file:exists() then
      return
    end
    -- Ignore partially-written / invalid JSON; keep the current token.
    local ok, token = pcall(vim.json.decode, token_file:read())
    if ok then
      M.state.github_token = token
    end
  end)
  watcher:start(copilot_path, {}, on_change)
end
-- Whether M.setup() has completed for this session.
M._is_setup = false
--- True when a Copilot OAuth token can be located on disk.
--- (H.get_oauth_token raises when no copilot.lua/copilot.vim config exists.)
---@return boolean
function M.is_env_set()
  local ok = pcall(H.get_oauth_token)
  return ok
end
--- Initialize provider state: load any cached token from disk, elect a
--- refresh manager via the lockfile, and start the token file watcher.
function M.setup()
  local copilot_token_file = Path:new(copilot_path)
  if not M.state then
    M.state = {
      github_token = nil,
      oauth_token = H.get_oauth_token(),
    }
  end
  -- Load and validate existing token
  if copilot_token_file:exists() then
    local ok, token = pcall(vim.json.decode, copilot_token_file:read())
    -- Only adopt a cached token that has not yet expired.
    if ok and token.expires_at and token.expires_at > math.floor(os.time()) then
      M.state.github_token = token
    end
  end
  -- Setup timer management
  if the following holds: we won the lock election, we own refreshing.
  local timer_lock_acquired = try_acquire_timer_lock()
  if timer_lock_acquired then
    M.setup_timer()
  else
    -- Not the manager: do a one-shot async refresh if needed and rely on
    -- the file watcher + manager-check timer afterwards.
    vim.schedule(function()
      H.refresh_token(true, false)
    end)
  end
  M.setup_file_watcher()
  start_manager_check_timer()
  require("avante.tokenizers").setup(M.tokenizer_id)
  vim.g.avante_login = true
  M._is_setup = true
end
--- Tear down timers, the file watcher, and (when owned) the lockfile.
function M.cleanup()
  -- Cleanup refresh timer
  if M._refresh_timer then
    M._refresh_timer:stop()
    M._refresh_timer:close()
    M._refresh_timer = nil
    -- Remove lockfile if we were the manager
    -- (holding a refresh timer implies this instance won the election).
    local lockfile = Path:new(lockfile_path)
    if lockfile:exists() then
      local content = lockfile:read()
      local pid = tonumber(content)
      -- Only delete the lockfile when it still carries our own PID.
      if pid and pid == vim.fn.getpid() then
        lockfile:rm()
      end
    end
  end
  -- Cleanup manager check timer
  if M._manager_check_timer then
    M._manager_check_timer:stop()
    M._manager_check_timer:close()
    M._manager_check_timer = nil
  end
  -- Cleanup file watcher
  if M._file_watcher then
    ---@diagnostic disable-next-line: param-type-mismatch
    M._file_watcher:stop()
    M._file_watcher = nil
  end
end
-- Register cleanup on Neovim exit so timers, the watcher and the
-- lockfile are released even when the user never calls cleanup().
vim.api.nvim_create_autocmd("VimLeavePre", {
  callback = function()
    M.cleanup()
  end,
})
return M

View File

@@ -0,0 +1,361 @@
local Utils = require("avante.utils")
local Providers = require("avante.providers")
local Clipboard = require("avante.clipboard")
local OpenAI = require("avante.providers").openai
local Prompts = require("avante.utils.prompts")
---@class AvanteProviderFunctor
local M = {}
-- Environment variable holding the Gemini API key.
M.api_key_name = "GEMINI_API_KEY"
-- Map internal roles onto Gemini's role names ("model", not "assistant").
M.role_map = {
  user = "user",
  assistant = "model",
}
--- Streaming responses are supported.
function M:is_disable_stream()
  return false
end
--- Convert an Avante tool spec into a Gemini functionDeclaration.
---@param tool AvanteLLMTool
---@return table declaration with name, description, optional JSON-schema parameters
function M:transform_to_function_declaration(tool)
  local properties, required = Utils.llm_tool_param_fields_to_json_schema(tool.param.fields)
  -- Omit "parameters" entirely when the tool takes no arguments.
  local schema
  if not vim.tbl_isempty(properties) then
    schema = {
      type = "object",
      properties = properties,
      required = required,
    }
  end
  -- Prefer the dynamic description when one is provided and truthy.
  local description
  if tool.get_description then
    description = tool.get_description()
  end
  if not description then
    description = tool.description
  end
  return {
    name = tool.name,
    description = description,
    parameters = schema,
  }
end
--- Convert Avante chat messages into Gemini "contents", merging the
--- system prompt into systemInstruction.
---@param opts AvantePromptOptions
---@return table body fragment with systemInstruction and contents
function M:parse_messages(opts)
  local provider_conf, _ = Providers.parse_config(self)
  local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
  local contents = {}
  local prev_role = nil
  -- Maps tool_use id -> tool name, for tool_result items that lack a name.
  local tool_id_to_name = {}
  vim.iter(opts.messages):each(function(message)
    local role = message.role
    -- Gemini expects alternating user/model turns; inject a filler turn
    -- when two consecutive messages share a role.
    if role == prev_role then
      if role == M.role_map["user"] then
        table.insert(
          contents,
          { role = M.role_map["assistant"], parts = {
            { text = "Ok, I understand." },
          } }
        )
      else
        table.insert(contents, { role = M.role_map["user"], parts = {
          { text = "Ok" },
        } })
      end
    end
    prev_role = role
    local parts = {}
    local content_items = message.content
    if type(content_items) == "string" then
      table.insert(parts, { text = content_items })
    elseif type(content_items) == "table" then
      ---@cast content_items AvanteLLMMessageContentItem[]
      for _, item in ipairs(content_items) do
        if type(item) == "string" then
          table.insert(parts, { text = item })
        elseif type(item) == "table" and item.type == "text" then
          table.insert(parts, { text = item.text })
        elseif type(item) == "table" and item.type == "image" then
          table.insert(parts, {
            inline_data = {
              mime_type = "image/png",
              data = item.source.data,
            },
          })
        elseif type(item) == "table" and item.type == "tool_use" and not use_ReAct_prompt then
          tool_id_to_name[item.id] = item.name
          -- Tool calls always come from the model side.
          role = "model"
          table.insert(parts, {
            functionCall = {
              name = item.name,
              args = item.input,
            },
          })
        elseif type(item) == "table" and item.type == "tool_result" and not use_ReAct_prompt then
          role = "function"
          -- Tool output may be JSON; fall back to the raw string otherwise.
          local ok, content = pcall(vim.json.decode, item.content)
          if not ok then
            content = item.content
          end
          -- item.name here refers to the name of the tool that was called,
          -- which is available in the tool_result content item prepared by llm.lua
          local tool_name = item.name
          if not tool_name then
            -- Fallback, though item.name should ideally always be present for tool_result
            tool_name = tool_id_to_name[item.tool_use_id]
          end
          table.insert(parts, {
            functionResponse = {
              name = tool_name,
              response = {
                name = tool_name, -- Gemini API requires the name in the response object as well
                content = content,
              },
            },
          })
        elseif type(item) == "table" and item.type == "thinking" then
          table.insert(parts, { text = item.thinking })
        elseif type(item) == "table" and item.type == "redacted_thinking" then
          table.insert(parts, { text = item.data })
        end
      end
      -- ReAct mode: tool calls/results travel as plain text (XML), so a
      -- tool_result message is rewritten into a model turn echoing the
      -- matching tool_use XML followed by a user turn with the result.
      if not provider_conf.disable_tools and use_ReAct_prompt then
        if content_items[1].type == "tool_result" then
          local tool_use_msg = nil
          for _, msg_ in ipairs(opts.messages) do
            if type(msg_.content) == "table" and #msg_.content > 0 then
              if
                msg_.content[1].type == "tool_use"
                and msg_.content[1].id == content_items[1].tool_use_id
              then
                tool_use_msg = msg_
                break
              end
            end
          end
          if tool_use_msg then
            table.insert(contents, {
              role = "model",
              parts = {
                { text = Utils.tool_use_to_xml(tool_use_msg.content[1]) },
              },
            })
            role = "user"
            table.insert(parts, {
              text = "The result of tool use "
                .. Utils.tool_use_to_xml(tool_use_msg.content[1])
                .. " is:\n",
            })
            table.insert(parts, {
              text = content_items[1].content,
            })
          end
        end
      end
    end
    if #parts > 0 then
      table.insert(contents, { role = M.role_map[role] or role, parts = parts })
    end
  end)
  -- Attach pasted images to the last turn.
  if Clipboard.support_paste_image() and opts.image_paths then
    for _, image_path in ipairs(opts.image_paths) do
      local image_data = {
        inline_data = {
          mime_type = "image/png",
          data = Clipboard.get_base64_content(image_path),
        },
      }
      table.insert(contents[#contents].parts, image_data)
    end
  end
  local system_prompt = opts.system_prompt
  if use_ReAct_prompt then
    system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts)
  end
  return {
    systemInstruction = {
      role = "user",
      parts = {
        {
          text = system_prompt,
        },
      },
    },
    contents = contents,
  }
end
--- Prepares the main request body for Gemini-like APIs.
---@param provider_instance AvanteProviderFunctor The provider instance (self).
---@param prompt_opts AvantePromptOptions Prompt options including messages, tools, system_prompt.
---@param provider_conf table Provider configuration from config.lua.
---@param request_body_ table Request-specific overrides from provider config.
---@return table The fully constructed request body.
function M.prepare_request_body(provider_instance, prompt_opts, provider_conf, request_body_)
  local use_react = provider_conf.use_ReAct_prompt == true
  local body = {
    generationConfig = request_body_.generationConfig or {},
  }
  if use_react then
    -- In ReAct mode tool calls are emitted as XML in plain text; stop
    -- generation once the closing tag has been produced.
    body.generationConfig.stopSequences = { "</tool_use>" }
  end
  local tools_enabled = not (provider_conf.disable_tools or false)
  if not use_react and tools_enabled and prompt_opts.tools then
    local declarations = {}
    for _, tool in ipairs(prompt_opts.tools) do
      declarations[#declarations + 1] = provider_instance:transform_to_function_declaration(tool)
    end
    if #declarations > 0 then
      body.tools = {
        {
          functionDeclarations = declarations,
        },
      }
    end
  end
  -- Messages (systemInstruction + contents) form the base; explicit body
  -- fields win on conflict.
  return vim.tbl_deep_extend("force", {}, provider_instance:parse_messages(prompt_opts), body)
end
--- Convert Gemini usage metadata into Avante's token-usage shape.
---@param usage avante.GeminiTokenUsage | nil
---@return avante.LLMTokenUsage | nil
function M.transform_gemini_usage(usage)
  if not usage then
    return nil
  end
  ---@type avante.LLMTokenUsage
  return {
    prompt_tokens = usage.promptTokenCount,
    completion_tokens = usage.candidatesTokenCount,
  }
end
--- Parse one SSE data chunk from Gemini's streamGenerateContent.
---@param ctx table streaming context shared across chunks (tool list, ids)
---@param data_stream string raw JSON payload of the SSE event
---@param opts table callbacks: on_chunk, on_stop, update_tokens_usage
function M:parse_response(ctx, data_stream, _, opts)
  local ok, jsn = pcall(vim.json.decode, data_stream)
  if not ok then
    opts.on_stop({ reason = "error", error = "Failed to parse JSON response: " .. tostring(jsn) })
    return
  end
  if opts.update_tokens_usage and jsn.usageMetadata and jsn.usageMetadata ~= nil then
    local usage = M.transform_gemini_usage(jsn.usageMetadata)
    if usage ~= nil then
      opts.update_tokens_usage(usage)
    end
  end
  -- Handle prompt feedback first, as it might indicate an overall issue with the prompt
  if jsn.promptFeedback and jsn.promptFeedback.blockReason then
    local feedback = jsn.promptFeedback
    OpenAI:finish_pending_messages(ctx, opts) -- Ensure any pending messages are cleared
    opts.on_stop({
      reason = "error",
      error = "Prompt blocked or filtered. Reason: " .. feedback.blockReason,
      details = feedback,
    })
    return
  end
  if jsn.candidates and #jsn.candidates > 0 then
    -- Only the first candidate is consumed.
    local candidate = jsn.candidates[1]
    ---@type AvanteLLMToolUse[]
    ctx.tool_use_list = ctx.tool_use_list or {}
    -- Check if candidate.content and candidate.content.parts exist before iterating
    if candidate.content and candidate.content.parts then
      for _, part in ipairs(candidate.content.parts) do
        if part.text then
          if opts.on_chunk then
            opts.on_chunk(part.text)
          end
          OpenAI:add_text_message(ctx, part.text, "generating", opts)
        elseif part.functionCall then
          -- Gemini does not assign call ids; synthesize one per turn.
          if not ctx.function_call_id then
            ctx.function_call_id = 0
          end
          ctx.function_call_id = ctx.function_call_id + 1
          local tool_use = {
            id = ctx.turn_id .. "-" .. tostring(ctx.function_call_id),
            name = part.functionCall.name,
            input_json = vim.json.encode(part.functionCall.args),
          }
          table.insert(ctx.tool_use_list, tool_use)
          OpenAI:add_tool_use_message(ctx, tool_use, "generated", opts)
        end
      end
    end
    -- Check for finishReason to determine if this candidate's stream is done.
    if candidate.finishReason then
      OpenAI:finish_pending_messages(ctx, opts)
      local reason_str = candidate.finishReason
      local stop_details = { finish_reason = reason_str }
      stop_details.usage = M.transform_gemini_usage(jsn.usageMetadata)
      if reason_str == "TOOL_CODE" then
        -- Model indicates a tool-related stop.
        -- The tool_use list is added to the table in llm.lua
        opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
      elseif reason_str == "STOP" then
        if ctx.tool_use_list and #ctx.tool_use_list > 0 then
          -- Natural stop, but tools were found in this final chunk.
          opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
        else
          -- Natural stop, no tools in this final chunk.
          -- llm.lua will check its accumulated tools if tool_choice was active.
          opts.on_stop(vim.tbl_deep_extend("force", { reason = "complete" }, stop_details))
        end
      elseif reason_str == "MAX_TOKENS" then
        opts.on_stop(vim.tbl_deep_extend("force", { reason = "max_tokens" }, stop_details))
      elseif reason_str == "SAFETY" or reason_str == "RECITATION" then
        opts.on_stop(
          vim.tbl_deep_extend(
            "force",
            { reason = "error", error = "Generation stopped: " .. reason_str },
            stop_details
          )
        )
      else -- OTHER, FINISH_REASON_UNSPECIFIED, or any other unhandled reason.
        opts.on_stop(
          vim.tbl_deep_extend(
            "force",
            { reason = "error", error = "Generation stopped with unhandled reason: " .. reason_str },
            stop_details
          )
        )
      end
    end
    -- If no finishReason, it's an intermediate chunk; do not call on_stop.
  end
end
--- Build the curl request for Gemini's streaming SSE endpoint.
---@param prompt_opts AvantePromptOptions
---@return AvanteCurlOutput|nil nil when the API key is missing
function M:parse_curl_args(prompt_opts)
  local provider_conf, request_body = Providers.parse_config(self)
  local api_key = self:parse_api_key()
  if api_key == nil then
    Utils.error("Gemini: API key is not set. Please set " .. M.api_key_name)
    return nil
  end
  -- The key is passed as a query parameter on the streaming endpoint.
  local endpoint_path = provider_conf.model .. ":streamGenerateContent?alt=sse&key=" .. api_key
  local headers = Utils.tbl_override({ ["Content-Type"] = "application/json" }, self.extra_headers)
  return {
    url = Utils.url_join(provider_conf.endpoint, endpoint_path),
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
    headers = headers,
    body = M.prepare_request_body(self, prompt_opts, provider_conf, request_body),
  }
end
return M

View File

@@ -8,19 +8,19 @@ local llm = require("codetyper.llm")
--- Get Ollama host from config
--- (the scraped diff had every statement duplicated, leaving shadowed
--- locals and an unreachable second return; deduplicated here)
---@return string Host URL
local function get_host()
  local codetyper = require("codetyper")
  local config = codetyper.get_config()
  return config.llm.ollama.host
end
--- Get model from config
--- (deduplicated: the scraped diff repeated each statement and return)
---@return string Model name
local function get_model()
  local codetyper = require("codetyper")
  local config = codetyper.get_config()
  return config.llm.ollama.model
end
--- Build request body for Ollama API
@@ -28,93 +28,96 @@ end
--- Build request body for Ollama's /api/generate endpoint.
--- (deduplicated: the scraped diff contained the body twice, the second
--- return being unreachable)
---@param prompt string User prompt
---@param context table Context information
---@return table Request body
local function build_request_body(prompt, context)
  local system_prompt = llm.build_system_prompt(context)
  return {
    model = get_model(),
    system = system_prompt,
    prompt = prompt,
    stream = false,
    options = {
      -- Low temperature for more deterministic code generation.
      temperature = 0.2,
      num_predict = 4096,
    },
  }
end
--- Make HTTP request to Ollama API
--- (reconstructed: the scraped diff interleaved the old and new versions
--- of this function, duplicating the cmd table, the jobstart call and
--- every handler; this is the single deduplicated implementation)
---@param body table Request body
---@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function
local function make_request(body, callback)
  local host = get_host()
  local url = host .. "/api/generate"
  local json_body = vim.json.encode(body)
  local cmd = {
    "curl",
    "-s",
    "-X",
    "POST",
    url,
    "-H",
    "Content-Type: application/json",
    "-d",
    json_body,
  }
  vim.fn.jobstart(cmd, {
    stdout_buffered = true,
    on_stdout = function(_, data)
      -- jobstart delivers {""} when there is no output; ignore it.
      if not data or #data == 0 or (data[1] == "" and #data == 1) then
        return
      end
      local response_text = table.concat(data, "\n")
      local ok, response = pcall(vim.json.decode, response_text)
      if not ok then
        vim.schedule(function()
          callback(nil, "Failed to parse Ollama response", nil)
        end)
        return
      end
      if response.error then
        vim.schedule(function()
          callback(nil, response.error or "Ollama API error", nil)
        end)
        return
      end
      -- Extract usage info (Ollama reports *_eval_count fields).
      local usage = {
        prompt_tokens = response.prompt_eval_count or 0,
        response_tokens = response.eval_count or 0,
      }
      if response.response then
        local code = llm.extract_code(response.response)
        vim.schedule(function()
          callback(code, nil, usage)
        end)
      else
        vim.schedule(function()
          callback(nil, "No response from Ollama", nil)
        end)
      end
    end,
    on_stderr = function(_, data)
      if data and #data > 0 and data[1] ~= "" then
        vim.schedule(function()
          callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil)
        end)
      end
    end,
    on_exit = function(_, code)
      if code ~= 0 then
        vim.schedule(function()
          callback(nil, "Ollama API request failed with code: " .. code, nil)
        end)
      end
    end,
  })
end
--- Generate code using Ollama API
@@ -122,111 +125,107 @@ end
--- Generate code using Ollama API
--- (reconstructed: the scraped diff interleaved old and new versions of
--- this function; this is the single deduplicated implementation)
---@param prompt string User prompt
---@param context table Context information
---@param callback fun(response: string|nil, error: string|nil) Callback function
function M.generate(prompt, context, callback)
  local logs = require("codetyper.agent.logs")
  local model = get_model()
  -- Log the request
  logs.request("ollama", model)
  logs.thinking("Building request body...")
  local body = build_request_body(prompt, context)
  -- Rough prompt-size estimate for the log (no tokenizer available).
  local prompt_estimate = logs.estimate_tokens(vim.json.encode(body))
  logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
  logs.thinking("Sending to Ollama API...")
  utils.notify("Sending request to Ollama...", vim.log.levels.INFO)
  make_request(body, function(response, err, usage)
    if err then
      logs.error(err)
      utils.notify(err, vim.log.levels.ERROR)
      callback(nil, err)
    else
      -- Log token usage
      if usage then
        logs.response(usage.prompt_tokens or 0, usage.response_tokens or 0, "end_turn")
      end
      logs.thinking("Response received, extracting code...")
      logs.info("Code generated successfully")
      utils.notify("Code generated successfully", vim.log.levels.INFO)
      callback(response, nil)
    end
  end)
end
--- Check if Ollama is reachable.
--- Deduplicated: the previous text issued the same curl job twice, which would
--- invoke the callback twice per health check.
---@param callback fun(ok: boolean, error: string|nil) Callback function
function M.health_check(callback)
  local host = get_host()
  local cmd = { "curl", "-s", host .. "/api/tags" }

  vim.fn.jobstart(cmd, {
    stdout_buffered = true,
    on_stdout = function(_, data)
      -- Any non-empty stdout means the server answered
      if data and #data > 0 and data[1] ~= "" then
        vim.schedule(function()
          callback(true, nil)
        end)
      end
    end,
    on_exit = function(_, code)
      if code ~= 0 then
        vim.schedule(function()
          callback(false, "Cannot connect to Ollama at " .. host)
        end)
      end
    end,
  })
end
--- Check if Ollama is properly configured.
--- Deduplicated: the previous text repeated the whole body after `return true`,
--- which is a Lua syntax error (statements after `return` in the same block).
---@return boolean ok True when host and model are configured
---@return string? err Error message when not valid
function M.validate()
  local host = get_host()
  if not host or host == "" then
    return false, "Ollama host not configured"
  end

  local model = get_model()
  if not model or model == "" then
    return false, "Ollama model not configured"
  end

  return true
end
--- Build system prompt for agent mode with tool instructions.
--- Deduplicated: the previous text contained every statement twice, including a
--- double `return`, which does not parse.
---@param context table Context information (may carry file_path / language)
---@return string System prompt
local function build_agent_system_prompt(context)
  local agent_prompts = require("codetyper.prompts.agent")
  local tools_module = require("codetyper.agent.tools")

  local system_prompt = agent_prompts.system .. "\n\n"
  system_prompt = system_prompt .. tools_module.to_prompt_format() .. "\n\n"
  system_prompt = system_prompt .. agent_prompts.tool_instructions

  -- Add context about current file if available
  if context.file_path then
    system_prompt = system_prompt .. "\n\nCurrent working context:\n"
    system_prompt = system_prompt .. "- File: " .. context.file_path .. "\n"
    if context.language then
      system_prompt = system_prompt .. "- Language: " .. context.language .. "\n"
    end
  end

  -- Add project root info
  local root = utils.get_project_root()
  if root then
    system_prompt = system_prompt .. "- Project root: " .. root .. "\n"
  end

  return system_prompt
end
--- Build request body for Ollama API with tools (chat format).
--- Deduplicated: the previous text repeated the loop and the `return` table,
--- leaving unreachable statements after `return` (a syntax error).
---@param messages table[] Conversation history ({ role, content } entries)
---@param context table Context information
---@return table Request body for POST /api/chat
local function build_tools_request_body(messages, context)
  local system_prompt = build_agent_system_prompt(context)

  -- Convert messages to Ollama chat format
  local ollama_messages = {}
  for _, msg in ipairs(messages) do
    local content = msg.content
    -- Handle complex content (like tool results): flatten structured parts
    -- into a single text payload, since Ollama chat takes plain strings.
    if type(content) == "table" then
      local text_parts = {}
      for _, part in ipairs(content) do
        if part.type == "tool_result" then
          table.insert(text_parts, "[" .. (part.name or "tool") .. " result]: " .. (part.content or ""))
        elseif part.type == "text" then
          table.insert(text_parts, part.text or "")
        end
      end
      content = table.concat(text_parts, "\n")
    end

    table.insert(ollama_messages, {
      role = msg.role,
      content = content,
    })
  end

  return {
    model = get_model(),
    messages = ollama_messages,
    system = system_prompt,
    stream = false,
    options = {
      temperature = 0.3,
      num_predict = 4096,
    },
  }
end
--- Make HTTP request to Ollama chat API.
--- Deduplicated: the previous text started the same curl job twice, so every
--- request was sent twice and the callback could fire twice.
---@param body table Request body
---@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function
local function make_chat_request(body, callback)
  local host = get_host()
  local url = host .. "/api/chat"
  local json_body = vim.json.encode(body)

  local cmd = {
    "curl",
    "-s",
    "-X", "POST",
    url,
    "-H", "Content-Type: application/json",
    "-d", json_body,
  }

  vim.fn.jobstart(cmd, {
    stdout_buffered = true,
    on_stdout = function(_, data)
      -- Buffered stdout: an empty table or a lone "" means no payload yet
      if not data or #data == 0 or (data[1] == "" and #data == 1) then
        return
      end

      local response_text = table.concat(data, "\n")
      local ok, response = pcall(vim.json.decode, response_text)

      if not ok then
        vim.schedule(function()
          callback(nil, "Failed to parse Ollama response", nil)
        end)
        return
      end

      if response.error then
        vim.schedule(function()
          callback(nil, response.error or "Ollama API error", nil)
        end)
        return
      end

      -- Extract usage info reported by Ollama
      local usage = {
        prompt_tokens = response.prompt_eval_count or 0,
        response_tokens = response.eval_count or 0,
      }

      -- Return the message content for agent parsing
      if response.message and response.message.content then
        vim.schedule(function()
          callback(response.message.content, nil, usage)
        end)
      else
        vim.schedule(function()
          callback(nil, "No response from Ollama", nil)
        end)
      end
    end,
    on_stderr = function(_, data)
      if data and #data > 0 and data[1] ~= "" then
        vim.schedule(function()
          callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil)
        end)
      end
    end,
    on_exit = function(_, code)
      if code ~= 0 then
        -- Don't double-report errors: on_stdout/on_stderr already did
      end
    end,
  })
end
--- Generate response with tools using Ollama API.
--- Deduplicated: the previous text called make_chat_request twice per call.
---@param messages table[] Conversation history
---@param context table Context information
---@param tools table Tool definitions (embedded in prompt for Ollama)
---@param callback fun(response: string|nil, error: string|nil) Callback function
function M.generate_with_tools(messages, context, tools, callback)
  local logs = require("codetyper.agent.logs")

  -- Log the request
  local model = get_model()
  logs.request("ollama", model)
  logs.thinking("Preparing API request...")

  local body = build_tools_request_body(messages, context)

  -- Estimate prompt tokens
  local prompt_estimate = logs.estimate_tokens(vim.json.encode(body))
  logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))

  make_chat_request(body, function(response, err, usage)
    if err then
      logs.error(err)
      callback(nil, err)
    else
      -- Log token usage
      if usage then
        logs.response(usage.prompt_tokens or 0, usage.response_tokens or 0, "end_turn")
      end

      -- Log if response contains tool calls
      if response then
        local parser = require("codetyper.agent.parser")
        local parsed = parser.parse_ollama_response(response)
        if #parsed.tool_calls > 0 then
          for _, tc in ipairs(parsed.tool_calls) do
            logs.thinking("Tool call: " .. tc.name)
          end
        end
        if parsed.text and parsed.text ~= "" then
          logs.thinking("Response contains text")
        end
      end

      callback(response, nil)
    end
  end)
end
return M

View File

@@ -0,0 +1,973 @@
local Utils = require("avante.utils")
local Config = require("avante.config")
local Clipboard = require("avante.clipboard")
local Providers = require("avante.providers")
local HistoryMessage = require("avante.history.message")
local ReActParser = require("avante.libs.ReAct_parser2")
local JsonParser = require("avante.libs.jsonparser")
local Prompts = require("avante.utils.prompts")
local LlmTools = require("avante.llm_tools")
---@class AvanteProviderFunctor
local M = {}

-- Name of the environment variable that holds the API key.
M.api_key_name = "OPENAI_API_KEY"

-- Maps internal message roles to the OpenAI wire-format roles.
M.role_map = {
  user = "user",
  assistant = "assistant",
}
--- Whether streaming should be disabled for this provider (it is not).
---@return boolean
function M:is_disable_stream()
  return false
end
--- Convert an Avante tool definition into the OpenAI function-tool shape.
---@param tool AvanteLLMTool
---@return AvanteOpenAITool
function M:transform_tool(tool)
  local properties, required_fields = Utils.llm_tool_param_fields_to_json_schema(tool.param.fields)

  ---@type AvanteOpenAIToolFunctionParameters
  local parameters = {
    type = "object",
    properties = properties,
    required = required_fields,
    additionalProperties = false,
  }

  -- Prefer the dynamic description when the tool provides one.
  local description = tool.get_description and tool.get_description() or tool.description

  ---@type AvanteOpenAITool
  return {
    type = "function",
    ["function"] = {
      name = tool.name,
      description = description,
      parameters = parameters,
    },
  }
end
--- True-ish when the endpoint URL points at OpenRouter.
function M.is_openrouter(url)
  -- Anchored Lua pattern; '%.' escapes the literal dots.
  local pattern = "^https://openrouter%.ai/"
  return url:match(pattern)
end
--- True-ish when the endpoint URL points at the Mistral API.
function M.is_mistral(url)
  -- Anchored Lua pattern; '%.' escapes the literal dots.
  local pattern = "^https://api%.mistral%.ai/"
  return url:match(pattern)
end
--- Deprecated: use `parse_messages` instead (kept only for compatibility).
---@param opts AvantePromptOptions
function M.get_user_message(opts)
  vim.deprecate("get_user_message", "parse_messages", "0.1.0", "avante.nvim")
  return table.concat(
    vim.iter(opts.messages)
      -- NOTE(review): vim.iter over a list passes the element as the FIRST
      -- argument, so `value` here is likely always nil and this filter keeps
      -- everything; the name also suggests it should keep role == "user",
      -- not exclude it — confirm before relying on this deprecated helper.
      :filter(function(_, value)
        return value == nil or value.role ~= "user"
      end)
      :fold({}, function(acc, value)
        acc = vim.list_extend({}, acc)
        acc = vim.list_extend(acc, { value.content })
        return acc
      end),
    "\n"
  )
end
--- Whether `model` is an OpenAI reasoning model: the o-series (o1, o3, ...)
--- or any gpt-5 variant except "gpt-5-chat".
function M.is_reasoning_model(model)
  if not model then
    return model
  end
  local is_o_series = string.match(model, "^o%d+") ~= nil
  local is_gpt5_family = string.match(model, "gpt%-5") ~= nil and model ~= "gpt-5-chat"
  return is_o_series or is_gpt5_family
end
--- Normalize request-body parameters for the selected model and API flavor.
--- Mutates `request_body` in place.
---@param provider_conf table Provider configuration (model, endpoint, flags)
---@param request_body table Request body to sanitize in place
function M.set_allowed_params(provider_conf, request_body)
  local use_response_api = Providers.resolve_use_response_api(provider_conf, nil)
  if M.is_reasoning_model(provider_conf.model) then
    -- Reasoning models have specific parameter requirements
    request_body.temperature = 1
    -- Response API doesn't support temperature for reasoning models
    if use_response_api then
      request_body.temperature = nil
    end
  else
    -- Non-reasoning models must not send reasoning-specific knobs
    request_body.reasoning_effort = nil
    request_body.reasoning = nil
  end
  -- If max_tokens is set in config, unset max_completion_tokens
  if request_body.max_tokens then
    request_body.max_completion_tokens = nil
  end
  -- Handle Response API specific parameters
  if use_response_api then
    -- Convert reasoning_effort to reasoning object for Response API
    if request_body.reasoning_effort then
      request_body.reasoning = {
        effort = request_body.reasoning_effort,
      }
      request_body.reasoning_effort = nil
    end
    -- Response API doesn't support some parameters
    -- Remove unsupported parameters for Response API
    local unsupported_params = {
      "top_p",
      "frequency_penalty",
      "presence_penalty",
      "logit_bias",
      "logprobs",
      "top_logprobs",
      "n",
    }
    for _, param in ipairs(unsupported_params) do
      request_body[param] = nil
    end
  end
end
--- Convert Avante history messages into OpenAI wire format, supporting both
--- the Chat Completions API and the Response API (event/input based), plus an
--- optional ReAct-style prompting mode where tool calls are embedded in text.
---@param opts AvantePromptOptions
---@return table[] final_messages Messages ready to serialize into the request
function M:parse_messages(opts)
  local messages = {}
  local provider_conf, _ = Providers.parse_config(self)
  local use_response_api = Providers.resolve_use_response_api(provider_conf, opts)
  local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
  local system_prompt = opts.system_prompt
  if use_ReAct_prompt then
    system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts)
  end
  -- Reasoning models take the system prompt under the "developer" role.
  if self.is_reasoning_model(provider_conf.model) then
    table.insert(messages, { role = "developer", content = system_prompt })
  else
    table.insert(messages, { role = "system", content = system_prompt })
  end
  local has_tool_use = false
  vim.iter(opts.messages):each(function(msg)
    if type(msg.content) == "string" then
      table.insert(messages, { role = self.role_map[msg.role], content = msg.content })
    elseif type(msg.content) == "table" then
      -- Check if this is a reasoning message (object with type "reasoning")
      if msg.content.type == "reasoning" then
        -- Add reasoning message directly (for Response API)
        table.insert(messages, {
          type = "reasoning",
          id = msg.content.id,
          encrypted_content = msg.content.encrypted_content,
          summary = msg.content.summary,
        })
        return
      end
      -- Split structured content into plain parts, tool calls, tool results.
      local content = {}
      local tool_calls = {}
      local tool_results = {}
      for _, item in ipairs(msg.content) do
        if type(item) == "string" then
          table.insert(content, { type = "text", text = item })
        elseif item.type == "text" then
          table.insert(content, { type = "text", text = item.text })
        elseif item.type == "image" then
          table.insert(content, {
            type = "image_url",
            image_url = {
              url = "data:"
                .. item.source.media_type
                .. ";"
                .. item.source.type
                .. ","
                .. item.source.data,
            },
          })
        elseif item.type == "reasoning" then
          -- Add reasoning message directly (for Response API)
          table.insert(messages, {
            type = "reasoning",
            id = item.id,
            encrypted_content = item.encrypted_content,
            summary = item.summary,
          })
        elseif item.type == "tool_use" and not use_ReAct_prompt then
          has_tool_use = true
          table.insert(tool_calls, {
            id = item.id,
            type = "function",
            ["function"] = { name = item.name, arguments = vim.json.encode(item.input) },
          })
        elseif item.type == "tool_result" and has_tool_use and not use_ReAct_prompt then
          table.insert(
            tool_results,
            {
              tool_call_id = item.tool_use_id,
              content = item.is_error and "Error: " .. item.content or item.content,
            }
          )
        end
      end
      -- In ReAct mode tool results are rewritten as user text that quotes the
      -- matching tool_use (found by scanning the history for the same id).
      if not provider_conf.disable_tools and use_ReAct_prompt then
        if msg.content[1].type == "tool_result" then
          local tool_use_msg = nil
          for _, msg_ in ipairs(opts.messages) do
            if type(msg_.content) == "table" and #msg_.content > 0 then
              if
                msg_.content[1].type == "tool_use"
                and msg_.content[1].id == msg.content[1].tool_use_id
              then
                tool_use_msg = msg_
                break
              end
            end
          end
          if tool_use_msg then
            msg.role = "user"
            table.insert(content, {
              type = "text",
              text = "The result of tool use "
                .. Utils.tool_use_to_xml(tool_use_msg.content[1])
                .. " is:\n",
            })
            table.insert(content, {
              type = "text",
              text = msg.content[1].content,
            })
          end
        end
      end
      if #content > 0 then
        table.insert(messages, { role = self.role_map[msg.role], content = content })
      end
      if not provider_conf.disable_tools and not use_ReAct_prompt then
        if #tool_calls > 0 then
          -- Only skip tool_calls if using Response API with previous_response_id support
          -- Copilot uses Response API format but doesn't support previous_response_id
          local should_include_tool_calls = not use_response_api
            or not provider_conf.support_previous_response_id
          if should_include_tool_calls then
            -- For Response API without previous_response_id support (like Copilot),
            -- convert tool_calls to function_call items in input
            if use_response_api then
              for _, tool_call in ipairs(tool_calls) do
                table.insert(messages, {
                  type = "function_call",
                  call_id = tool_call.id,
                  name = tool_call["function"].name,
                  arguments = tool_call["function"].arguments,
                })
              end
            else
              -- Chat Completions API format: merge into a trailing assistant
              -- message that already carries tool_calls, else append one.
              local last_message = messages[#messages]
              if
                last_message
                and last_message.role == self.role_map["assistant"]
                and last_message.tool_calls
              then
                last_message.tool_calls = vim.list_extend(last_message.tool_calls, tool_calls)
                if not last_message.content then
                  last_message.content = ""
                end
              else
                table.insert(
                  messages,
                  { role = self.role_map["assistant"], tool_calls = tool_calls, content = "" }
                )
              end
            end
          end
          -- If support_previous_response_id is true, Response API manages function call history
          -- So we can skip adding tool_calls to input messages
        end
        if #tool_results > 0 then
          for _, tool_result in ipairs(tool_results) do
            -- Response API uses different format for function outputs
            if use_response_api then
              table.insert(messages, {
                type = "function_call_output",
                call_id = tool_result.tool_call_id,
                output = tool_result.content or "",
              })
            else
              table.insert(
                messages,
                {
                  role = "tool",
                  tool_call_id = tool_result.tool_call_id,
                  content = tool_result.content or "",
                }
              )
            end
          end
        end
      end
    end
  end)
  -- Attach clipboard images to the last message as base64 data URLs.
  if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
    local message_content = messages[#messages].content
    if type(message_content) ~= "table" or message_content[1] == nil then
      message_content = { { type = "text", text = message_content } }
    end
    for _, image_path in ipairs(opts.image_paths) do
      table.insert(message_content, {
        type = "image_url",
        image_url = {
          url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
        },
      })
    end
    messages[#messages].content = message_content
  end
  -- Final pass: break up consecutive same-role messages with filler turns
  -- ("Ok" / "Ok, I understand."), which some endpoints require.
  local final_messages = {}
  local prev_role = nil
  local prev_type = nil
  vim.iter(messages):each(function(message)
    local role = message.role
    if
      role == prev_role
      and role ~= "tool"
      and prev_type ~= "function_call"
      and prev_type ~= "function_call_output"
    then
      if role == self.role_map["assistant"] then
        table.insert(final_messages, { role = self.role_map["user"], content = "Ok" })
      else
        table.insert(final_messages, { role = self.role_map["assistant"], content = "Ok, I understand." })
      end
    else
      -- Mistral rejects a user turn directly after a tool turn
      if role == "user" and prev_role == "tool" and M.is_mistral(provider_conf.endpoint) then
        table.insert(final_messages, { role = self.role_map["assistant"], content = "Ok, I understand." })
      end
    end
    prev_role = role
    prev_type = message.type
    table.insert(final_messages, message)
  end)
  return final_messages
end
--- Flush any still-streaming state: finalize buffered text content and mark
--- every tool-use entry that is still "generating" as "generated".
function M:finish_pending_messages(ctx, opts)
  local has_text = ctx.content ~= nil and ctx.content ~= ""
  if has_text then
    self:add_text_message(ctx, "", "generated", opts)
  end
  local pending = ctx.tool_use_map
  if pending == nil then
    return
  end
  for _, entry in pairs(pending) do
    if entry.state == "generating" then
      self:add_tool_use_message(ctx, entry, "generated", opts)
    end
  end
end
-- Lazily-cached list of known tool names (filled on first add_text_message call).
local llm_tool_names = nil
--- Accumulate streamed assistant text into ctx.content, emit it as a history
--- message, and scan the accumulated text for embedded ReAct/XML tool calls,
--- turning each recognized one into a tool_use history message and registering
--- it in ctx.tool_use_map.
function M:add_text_message(ctx, text, state, opts)
  if llm_tool_names == nil then
    llm_tool_names = LlmTools.get_tool_names()
  end
  if ctx.content == nil then
    ctx.content = ""
  end
  ctx.content = ctx.content .. text
  -- Strip wrapper tags the model sometimes emits around tool calls.
  local content =
    ctx.content:gsub("<tool_code>", ""):gsub("</tool_code>", ""):gsub("<tool_call>", ""):gsub("</tool_call>", "")
  ctx.content = content
  local msg = HistoryMessage:new("assistant", ctx.content, {
    state = state,
    uuid = ctx.content_uuid,
    original_content = ctx.content,
  })
  ctx.content_uuid = msg.uuid
  local msgs = { msg }
  -- Rewrite <tool_name>/<parameters> markup into <name>...</name> tags that
  -- the ReAct parser understands.
  local xml_content = ctx.content
  local xml_lines = vim.split(xml_content, "\n")
  local cleaned_xml_lines = {}
  local prev_tool_name = nil
  for _, line in ipairs(xml_lines) do
    if line:match("<tool_name>") then
      local tool_name = line:match("<tool_name>(.*)</tool_name>")
      if tool_name then
        prev_tool_name = tool_name
      end
    elseif line:match("<parameters>") then
      if prev_tool_name then
        table.insert(cleaned_xml_lines, "<" .. prev_tool_name .. ">")
      end
      goto continue
    elseif line:match("</parameters>") then
      if prev_tool_name then
        table.insert(cleaned_xml_lines, "</" .. prev_tool_name .. ">")
      end
      goto continue
    end
    table.insert(cleaned_xml_lines, line)
    ::continue::
  end
  local cleaned_xml_content = table.concat(cleaned_xml_lines, "\n")
  local xml = ReActParser.parse(cleaned_xml_content)
  if xml and #xml > 0 then
    local new_content_list = {}
    local xml_md_openned = false
    for idx, item in ipairs(xml) do
      if item.type == "text" then
        -- Drop markdown fences that only wrapped tool-call XML.
        local cleaned_lines = {}
        local lines = vim.split(item.text, "\n")
        for _, line in ipairs(lines) do
          if line:match("^```xml") or line:match("^```tool_code") or line:match("^```tool_use") then
            xml_md_openned = true
          elseif line:match("^```$") then
            if xml_md_openned then
              xml_md_openned = false
            else
              table.insert(cleaned_lines, line)
            end
          else
            table.insert(cleaned_lines, line)
          end
        end
        table.insert(new_content_list, table.concat(cleaned_lines, "\n"))
        goto continue
      end
      -- Ignore XML elements that don't name a registered tool.
      if not vim.tbl_contains(llm_tool_names, item.tool_name) then
        goto continue
      end
      -- Parameter values may themselves be JSON; decode when possible.
      local input = {}
      for k, v in pairs(item.tool_input or {}) do
        local ok, jsn = pcall(vim.json.decode, v)
        if ok and jsn then
          input[k] = jsn
        else
          input[k] = v
        end
      end
      if next(input) ~= nil then
        local msg_uuid = ctx.content_uuid .. "-" .. idx
        local tool_use_id = msg_uuid
        local tool_message_state = item.partial and "generating" or "generated"
        local msg_ = HistoryMessage:new("assistant", {
          type = "tool_use",
          name = item.tool_name,
          id = tool_use_id,
          input = input,
        }, {
          state = tool_message_state,
          uuid = msg_uuid,
          turn_id = ctx.turn_id,
        })
        msgs[#msgs + 1] = msg_
        ctx.tool_use_map = ctx.tool_use_map or {}
        local input_json = type(input) == "string" and input or vim.json.encode(input)
        -- Update the existing map entry if this tool use was already seen.
        local exists = false
        for _, tool_use in pairs(ctx.tool_use_map) do
          if tool_use.id == tool_use_id then
            tool_use.input_json = input_json
            exists = true
          end
        end
        if not exists then
          local tool_key = tostring(vim.tbl_count(ctx.tool_use_map))
          ctx.tool_use_map[tool_key] = {
            uuid = tool_use_id,
            id = tool_use_id,
            name = item.tool_name,
            input_json = input_json,
            state = "generating",
          }
        end
        opts.on_stop({ reason = "tool_use", streaming_tool_use = item.partial })
      end
      ::continue::
    end
    -- Replace the text message body with the cleaned, tool-free text.
    msg.message.content = table.concat(new_content_list, "\n"):gsub("\n+$", "\n")
  end
  if opts.on_messages_add then
    opts.on_messages_add(msgs)
  end
end
--- Accumulate streamed reasoning text and emit it as a "thinking" history
--- message. NOTE: the ctx field name "reasonging_content" (sic) is kept as-is
--- for compatibility with the rest of the file.
function M:add_thinking_message(ctx, text, state, opts)
  ctx.reasonging_content = (ctx.reasonging_content or "") .. text
  local payload = {
    type = "thinking",
    thinking = ctx.reasonging_content,
    signature = "",
  }
  local options = {
    state = state,
    uuid = ctx.reasonging_content_uuid,
    turn_id = ctx.turn_id,
  }
  local msg = HistoryMessage:new("assistant", payload, options)
  ctx.reasonging_content_uuid = msg.uuid
  if opts.on_messages_add then
    opts.on_messages_add({ msg })
  end
end
--- Emit a tool_use history message for a (possibly still streaming) tool call
--- and keep the tool_use entry's uuid/state in sync with the emitted message.
function M:add_tool_use_message(ctx, tool_use, state, opts)
  local parsed_input = JsonParser.parse(tool_use.input_json)
  local payload = {
    type = "tool_use",
    name = tool_use.name,
    id = tool_use.id,
    input = parsed_input or {},
  }
  local options = {
    state = state,
    uuid = tool_use.uuid,
    turn_id = ctx.turn_id,
  }
  local msg = HistoryMessage:new("assistant", payload, options)
  tool_use.uuid = msg.uuid
  tool_use.state = state
  if opts.on_messages_add then
    opts.on_messages_add({ msg })
  end
  -- While arguments are still streaming, let the caller know a tool use is
  -- in flight.
  if state == "generating" then
    opts.on_stop({ reason = "tool_use", streaming_tool_use = true })
  end
end
--- Record a completed Response API reasoning item in the chat history.
function M:add_reasoning_message(ctx, reasoning_item, opts)
  local payload = {
    type = "reasoning",
    id = reasoning_item.id,
    encrypted_content = reasoning_item.encrypted_content,
    summary = reasoning_item.summary,
  }
  local options = {
    state = "generated",
    uuid = Utils.uuid(),
    turn_id = ctx.turn_id,
  }
  local msg = HistoryMessage:new("assistant", payload, options)
  if opts.on_messages_add then
    opts.on_messages_add({ msg })
  end
end
--- Map OpenAI token-usage fields onto Avante's usage shape.
---@param usage avante.OpenAITokenUsage | nil
---@return avante.LLMTokenUsage | nil
function M.transform_openai_usage(usage)
  -- Treat Lua nil and JSON-decoded null (vim.NIL) the same way.
  if not usage or usage == vim.NIL then
    return nil
  end
  ---@type avante.LLMTokenUsage
  return {
    prompt_tokens = usage.prompt_tokens,
    completion_tokens = usage.completion_tokens,
  }
end
--- Streaming response parser. Handles three wire formats: the "[DONE]"
--- sentinel, Response API events (objects with a string `type`), and the
--- classic Chat Completions delta format. Accumulates streaming state in
--- `ctx` and reports progress through `opts` callbacks.
-- NOTE(review): `vim.json.decode(data_stream)` is not wrapped in pcall, so a
-- malformed chunk would raise — confirm upstream guarantees valid JSON chunks.
function M:parse_response(ctx, data_stream, _, opts)
  if data_stream:match('"%[DONE%]":') or data_stream == "[DONE]" then
    self:finish_pending_messages(ctx, opts)
    if ctx.tool_use_map and vim.tbl_count(ctx.tool_use_map) > 0 then
      ctx.tool_use_map = {}
      opts.on_stop({ reason = "tool_use" })
    else
      opts.on_stop({ reason = "complete" })
    end
    return
  end
  local jsn = vim.json.decode(data_stream)
  -- Check if this is a Response API event (has 'type' field)
  if jsn.type and type(jsn.type) == "string" then
    -- Response API event-driven format
    if jsn.type == "response.output_text.delta" then
      -- Text content delta
      if jsn.delta and jsn.delta ~= vim.NIL and jsn.delta ~= "" then
        if opts.on_chunk then
          opts.on_chunk(jsn.delta)
        end
        self:add_text_message(ctx, jsn.delta, "generating", opts)
      end
    elseif jsn.type == "response.reasoning_summary_text.delta" then
      -- Reasoning summary delta
      if jsn.delta and jsn.delta ~= vim.NIL and jsn.delta ~= "" then
        -- Emit the opening <think> tag exactly once per response.
        if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
          ctx.returned_think_start_tag = true
          if opts.on_chunk then
            opts.on_chunk("<think>\n")
          end
        end
        ctx.last_think_content = jsn.delta
        self:add_thinking_message(ctx, jsn.delta, "generating", opts)
        if opts.on_chunk then
          opts.on_chunk(jsn.delta)
        end
      end
    elseif jsn.type == "response.function_call_arguments.delta" then
      -- Function call arguments delta
      if jsn.delta and jsn.delta ~= vim.NIL and jsn.delta ~= "" then
        if not ctx.tool_use_map then
          ctx.tool_use_map = {}
        end
        -- Tool uses are keyed by output_index so deltas append correctly.
        local tool_key = tostring(jsn.output_index or 0)
        if not ctx.tool_use_map[tool_key] then
          ctx.tool_use_map[tool_key] = {
            name = jsn.name or "",
            id = jsn.call_id or "",
            input_json = jsn.delta,
          }
        else
          ctx.tool_use_map[tool_key].input_json = ctx.tool_use_map[tool_key].input_json .. jsn.delta
        end
      end
    elseif jsn.type == "response.output_item.added" then
      -- Output item added (could be function call or reasoning)
      if jsn.item and jsn.item.type == "function_call" then
        local tool_key = tostring(jsn.output_index or 0)
        if not ctx.tool_use_map then
          ctx.tool_use_map = {}
        end
        ctx.tool_use_map[tool_key] = {
          name = jsn.item.name or "",
          id = jsn.item.call_id or jsn.item.id or "",
          input_json = "",
        }
        self:add_tool_use_message(ctx, ctx.tool_use_map[tool_key], "generating", opts)
      elseif jsn.item and jsn.item.type == "reasoning" then
        -- Add reasoning item to history
        self:add_reasoning_message(ctx, jsn.item, opts)
      end
    elseif jsn.type == "response.output_item.done" then
      -- Output item done (finalize function call)
      if jsn.item and jsn.item.type == "function_call" then
        local tool_key = tostring(jsn.output_index or 0)
        if ctx.tool_use_map and ctx.tool_use_map[tool_key] then
          local tool_use = ctx.tool_use_map[tool_key]
          if jsn.item.arguments then
            tool_use.input_json = jsn.item.arguments
          end
          self:add_tool_use_message(ctx, tool_use, "generated", opts)
        end
      end
    elseif jsn.type == "response.completed" or jsn.type == "response.done" then
      -- Response completed - save response.id for future requests
      if jsn.response and jsn.response.id then
        ctx.last_response_id = jsn.response.id
        -- Store in provider for next request
        self.last_response_id = jsn.response.id
      end
      -- Close a still-open <think> block before finishing.
      if
        ctx.returned_think_start_tag ~= nil
        and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
      then
        ctx.returned_think_end_tag = true
        if opts.on_chunk then
          if
            ctx.last_think_content
            and ctx.last_think_content ~= vim.NIL
            and ctx.last_think_content:sub(-1) ~= "\n"
          then
            opts.on_chunk("\n</think>\n")
          else
            opts.on_chunk("</think>\n")
          end
        end
        self:add_thinking_message(ctx, "", "generated", opts)
      end
      self:finish_pending_messages(ctx, opts)
      local usage = nil
      if jsn.response and jsn.response.usage then
        usage = self.transform_openai_usage(jsn.response.usage)
      end
      if ctx.tool_use_map and vim.tbl_count(ctx.tool_use_map) > 0 then
        opts.on_stop({ reason = "tool_use", usage = usage })
      else
        opts.on_stop({ reason = "complete", usage = usage })
      end
    elseif jsn.type == "error" then
      -- Error event
      local error_msg = jsn.error and vim.inspect(jsn.error) or "Unknown error"
      opts.on_stop({ reason = "error", error = error_msg })
    end
    return
  end
  -- Chat Completions API format (original code)
  if jsn.usage and jsn.usage ~= vim.NIL then
    if opts.update_tokens_usage then
      local usage = self.transform_openai_usage(jsn.usage)
      if usage then
        opts.update_tokens_usage(usage)
      end
    end
  end
  if jsn.error and jsn.error ~= vim.NIL then
    opts.on_stop({ reason = "error", error = vim.inspect(jsn.error) })
    return
  end
  ---@cast jsn AvanteOpenAIChatResponse
  if not jsn.choices then
    return
  end
  local choice = jsn.choices[1]
  if not choice then
    return
  end
  local delta = choice.delta
  if not delta then
    -- o1-style responses put the payload under choice.message instead.
    local provider_conf = Providers.parse_config(self)
    if provider_conf.model:match("o1") then
      delta = choice.message
    end
  end
  if not delta then
    return
  end
  if delta.reasoning_content and delta.reasoning_content ~= vim.NIL and delta.reasoning_content ~= "" then
    if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
      ctx.returned_think_start_tag = true
      if opts.on_chunk then
        opts.on_chunk("<think>\n")
      end
    end
    ctx.last_think_content = delta.reasoning_content
    self:add_thinking_message(ctx, delta.reasoning_content, "generating", opts)
    if opts.on_chunk then
      opts.on_chunk(delta.reasoning_content)
    end
  elseif delta.reasoning and delta.reasoning ~= vim.NIL then
    if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
      ctx.returned_think_start_tag = true
      if opts.on_chunk then
        opts.on_chunk("<think>\n")
      end
    end
    ctx.last_think_content = delta.reasoning
    self:add_thinking_message(ctx, delta.reasoning, "generating", opts)
    if opts.on_chunk then
      opts.on_chunk(delta.reasoning)
    end
  elseif delta.tool_calls and delta.tool_calls ~= vim.NIL then
    local choice_index = choice.index or 0
    for idx, tool_call in ipairs(delta.tool_calls) do
      --- In Gemini's so-called OpenAI Compatible API, tool_call.index is nil, which is quite absurd! Therefore, a compatibility fix is needed here.
      if tool_call.index == nil then
        tool_call.index = choice_index + idx - 1
      end
      if not ctx.tool_use_map then
        ctx.tool_use_map = {}
      end
      local tool_key = tostring(tool_call.index)
      local prev_tool_key = tostring(tool_call.index - 1)
      if not ctx.tool_use_map[tool_key] then
        -- A new index implies the previous tool call finished streaming.
        local prev_tool_use = ctx.tool_use_map[prev_tool_key]
        if tool_call.index > 0 and prev_tool_use then
          self:add_tool_use_message(ctx, prev_tool_use, "generated", opts)
        end
        local tool_use = {
          name = tool_call["function"].name,
          id = tool_call.id,
          input_json = type(tool_call["function"].arguments) == "string" and tool_call["function"].arguments
            or "",
        }
        ctx.tool_use_map[tool_key] = tool_use
        self:add_tool_use_message(ctx, tool_use, "generating", opts)
      else
        local tool_use = ctx.tool_use_map[tool_key]
        if tool_call["function"].arguments == vim.NIL then
          tool_call["function"].arguments = ""
        end
        tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
        -- self:add_tool_use_message(ctx, tool_use, "generating", opts)
      end
    end
  elseif delta.content then
    -- First content delta after reasoning closes the <think> block.
    if
      ctx.returned_think_start_tag ~= nil
      and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
    then
      ctx.returned_think_end_tag = true
      if opts.on_chunk then
        if
          ctx.last_think_content
          and ctx.last_think_content ~= vim.NIL
          and ctx.last_think_content:sub(-1) ~= "\n"
        then
          opts.on_chunk("\n</think>\n")
        else
          opts.on_chunk("</think>\n")
        end
      end
      self:add_thinking_message(ctx, "", "generated", opts)
    end
    if delta.content ~= vim.NIL then
      if opts.on_chunk then
        opts.on_chunk(delta.content)
      end
      self:add_text_message(ctx, delta.content, "generating", opts)
    end
  end
  if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" or choice.finish_reason == "length" then
    self:finish_pending_messages(ctx, opts)
    if ctx.tool_use_map and vim.tbl_count(ctx.tool_use_map) > 0 then
      opts.on_stop({ reason = "tool_use", usage = self.transform_openai_usage(jsn.usage) })
    else
      opts.on_stop({ reason = "complete", usage = self.transform_openai_usage(jsn.usage) })
    end
  end
  if choice.finish_reason == "tool_calls" then
    self:finish_pending_messages(ctx, opts)
    opts.on_stop({
      reason = "tool_use",
      usage = self.transform_openai_usage(jsn.usage),
    })
  end
end
---Parse a complete (non-streaming) chat completion response body.
---Decodes the JSON payload, forwards the assistant message content through
---`opts.on_chunk`, records it as a finished text message, and signals
---completion via `opts.on_stop` on the main loop.
---@param data string Raw JSON body of the chat completion response
---@param _ any Unused (event-state placeholder kept for interface parity)
---@param opts table Callbacks: optional `on_chunk(text)` and `on_stop(info)`
function M:parse_response_without_stream(data, _, opts)
  -- vim.json.decode raises on malformed input; guard it so a truncated or
  -- non-JSON payload cannot crash the response handler.
  local ok, json = pcall(vim.json.decode, data)
  if not ok or type(json) ~= "table" then return end
  ---@type AvanteOpenAIChatResponse
  if json.choices and json.choices[1] then
    local choice = json.choices[1]
    if choice.message and choice.message.content then
      if opts.on_chunk then
        opts.on_chunk(choice.message.content)
      end
      self:add_text_message({}, choice.message.content, "generated", opts)
      -- Defer on_stop so it runs outside the job callback context.
      vim.schedule(function()
        opts.on_stop({ reason = "complete" })
      end)
    end
  end
end
---Build the curl request (url, headers, body) for an OpenAI-compatible
---endpoint. Supports both the classic Chat Completions API and the newer
---Response API (selected per request), plus ReAct-style tool prompting.
---Mutates `request_body` in place (token-limit renames, include_reasoning)
---and consumes `self.last_response_id` when replying with tool outputs.
---Returns nil after reporting an error when a required API key is missing.
---@param prompt_opts AvantePromptOptions
---@return AvanteCurlOutput|nil
function M:parse_curl_args(prompt_opts)
  local provider_conf, request_body = Providers.parse_config(self)
  local disable_tools = provider_conf.disable_tools or false
  local headers = {
    ["Content-Type"] = "application/json",
  }
  if Providers.env.require_api_key(provider_conf) then
    local api_key = self.parse_api_key()
    if api_key == nil then
      Utils.error(
        Config.provider .. ": API key is not set, please set it in your environment variable or config file"
      )
      return nil
    end
    headers["Authorization"] = "Bearer " .. api_key
  end
  -- OpenRouter-specific attribution headers; include_reasoning asks it to
  -- stream the model's reasoning tokens back.
  if M.is_openrouter(provider_conf.endpoint) then
    headers["HTTP-Referer"] = "https://github.com/yetone/avante.nvim"
    headers["X-Title"] = "Avante.nvim"
    request_body.include_reasoning = true
  end
  self.set_allowed_params(provider_conf, request_body)
  local use_response_api = Providers.resolve_use_response_api(provider_conf, prompt_opts)
  local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
  local tools = nil
  -- ReAct prompting encodes tool use in the prompt text itself, so native
  -- tool definitions are only sent when ReAct is off and tools are enabled.
  if not disable_tools and prompt_opts.tools and not use_ReAct_prompt then
    tools = {}
    for _, tool in ipairs(prompt_opts.tools) do
      local transformed_tool = self:transform_tool(tool)
      -- Response API uses flattened tool structure
      if use_response_api then
        -- Convert from {type: "function", function: {name, description, parameters}}
        -- to {type: "function", name, description, parameters}
        if transformed_tool.type == "function" and transformed_tool["function"] then
          transformed_tool = {
            type = "function",
            name = transformed_tool["function"].name,
            description = transformed_tool["function"].description,
            parameters = transformed_tool["function"].parameters,
          }
        end
      end
      table.insert(tools, transformed_tool)
    end
  end
  Utils.debug("endpoint", provider_conf.endpoint)
  Utils.debug("model", provider_conf.model)
  local stop = nil
  if use_ReAct_prompt then
    -- Halt generation as soon as the model closes its ReAct tool-use block.
    stop = { "</tool_use>" }
  end
  -- Determine endpoint path based on use_response_api
  local endpoint_path = use_response_api and "/responses" or "/chat/completions"
  local parsed_messages = self:parse_messages(prompt_opts)
  -- Build base body
  local base_body = {
    model = provider_conf.model,
    stop = stop,
    stream = true,
    tools = tools,
  }
  -- Response API uses 'input' instead of 'messages'
  if use_response_api then
    -- Check if we have tool results - if so, use previous_response_id
    local has_function_outputs = false
    for _, msg in ipairs(parsed_messages) do
      if msg.type == "function_call_output" then
        has_function_outputs = true
        break
      end
    end
    if has_function_outputs and self.last_response_id then
      -- When sending function outputs, use previous_response_id
      base_body.previous_response_id = self.last_response_id
      -- Only send the function outputs, not the full history
      local function_outputs = {}
      for _, msg in ipairs(parsed_messages) do
        if msg.type == "function_call_output" then
          table.insert(function_outputs, msg)
        end
      end
      base_body.input = function_outputs
      -- Clear the stored response_id after using it
      self.last_response_id = nil
    else
      -- Normal request without tool results
      base_body.input = parsed_messages
    end
    -- Response API uses max_output_tokens instead of max_tokens/max_completion_tokens
    -- NOTE(review): if both fields are set, max_tokens wins (second rename
    -- overwrites the first) — confirm that precedence is intended.
    if request_body.max_completion_tokens then
      request_body.max_output_tokens = request_body.max_completion_tokens
      request_body.max_completion_tokens = nil
    end
    if request_body.max_tokens then
      request_body.max_output_tokens = request_body.max_tokens
      request_body.max_tokens = nil
    end
    -- Response API doesn't use stream_options
    base_body.stream_options = nil
  else
    base_body.messages = parsed_messages
    -- Mistral endpoints reject stream_options; others get usage data
    -- included in the final stream chunk.
    base_body.stream_options = not M.is_mistral(provider_conf.endpoint) and {
      include_usage = true,
    } or nil
  end
  -- request_body (provider/user overrides) takes precedence over the
  -- computed base body via the "force" deep-extend.
  return {
    url = Utils.url_join(provider_conf.endpoint, endpoint_path),
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
    headers = Utils.tbl_override(headers, self.extra_headers),
    body = vim.tbl_deep_extend("force", base_body, request_body),
  }
end
return M