adding claude.lua
This commit is contained in:
@@ -361,4 +361,148 @@ function M.format_messages_for_claude(messages)
|
||||
return formatted
|
||||
end
|
||||
|
||||
--- Generate with tool use support for agentic mode.
---@param messages table[] Conversation history
---@param context table Context information
---@param tool_definitions table Tool definitions (unused here; tools come from codetyper.agent.tools — TODO confirm)
---@param callback fun(response: table|nil, error: string|nil) Callback with raw response
function M.generate_with_tools(messages, context, tool_definitions, callback)
  local api_key = get_api_key()
  if not api_key then
    callback(nil, "Claude API key not configured")
    return
  end

  local tools_module = require("codetyper.agent.tools")
  local agent_prompts = require("codetyper.prompts.agent")

  -- Build system prompt with agent instructions appended
  local system_prompt = llm.build_system_prompt(context)
  system_prompt = system_prompt .. "\n\n" .. agent_prompts.system
  system_prompt = system_prompt .. "\n\n" .. agent_prompts.tool_instructions

  -- Build request body with tools
  local body = {
    model = get_model(),
    max_tokens = 4096,
    system = system_prompt,
    messages = M.format_messages_for_claude(messages),
    tools = tools_module.to_claude_format(),
  }

  local json_body = vim.json.encode(body)

  local cmd = {
    "curl",
    "-s",
    "-X", "POST",
    API_URL,
    "-H", "Content-Type: application/json",
    "-H", "x-api-key: " .. api_key,
    "-H", "anthropic-version: 2023-06-01",
    "-d", json_body,
  }

  -- Fix: the stdout, stderr and exit handlers could each invoke the
  -- callback, so a failed request could report more than once. Route every
  -- completion through a one-shot guard instead.
  local finished = false
  local function finish(response, err)
    if finished then
      return
    end
    finished = true
    vim.schedule(function()
      callback(response, err)
    end)
  end

  vim.fn.jobstart(cmd, {
    stdout_buffered = true,
    on_stdout = function(_, data)
      -- jobstart delivers a single { "" } chunk when there was no output
      if not data or #data == 0 or (data[1] == "" and #data == 1) then
        return
      end

      local response_text = table.concat(data, "\n")
      local ok, response = pcall(vim.json.decode, response_text)

      if not ok then
        finish(nil, "Failed to parse Claude response")
        return
      end

      if response.error then
        finish(nil, response.error.message or "Claude API error")
        return
      end

      -- Return raw response for parser to handle
      finish(response, nil)
    end,
    on_stderr = function(_, data)
      if data and #data > 0 and data[1] ~= "" then
        finish(nil, "Claude API request failed: " .. table.concat(data, "\n"))
      end
    end,
    on_exit = function(_, code)
      if code ~= 0 then
        finish(nil, "Claude API request failed with code: " .. code)
      end
    end,
  })
end
|
||||
|
||||
--- Format messages for Claude API.
-- User messages pass through unchanged (string content or tool-result
-- arrays alike); assistant messages are converted to a content array of
-- text and tool_use blocks. Assistant messages with no content are dropped.
---@param messages table[] Internal message format
---@return table[] Claude API message format
function M.format_messages_for_claude(messages)
  local formatted = {}

  for _, msg in ipairs(messages) do
    if msg.role == "user" then
      -- Fix: the previous table/string branches were byte-identical;
      -- both string content and tool-result arrays are forwarded as-is.
      table.insert(formatted, {
        role = "user",
        content = msg.content,
      })
    elseif msg.role == "assistant" then
      -- Build content array for assistant messages
      local content = {}

      -- Add text if present
      if msg.content and msg.content ~= "" then
        table.insert(content, {
          type = "text",
          text = msg.content,
        })
      end

      -- Add tool uses if present
      if msg.tool_calls then
        for _, tool_call in ipairs(msg.tool_calls) do
          table.insert(content, {
            type = "tool_use",
            id = tool_call.id,
            name = tool_call.name,
            input = tool_call.parameters,
          })
        end
      end

      if #content > 0 then
        table.insert(formatted, {
          role = "assistant",
          content = content,
        })
      end
    end
  end

  return formatted
end
|
||||
|
||||
return M
|
||||
|
||||
531
lua/codetyper/llm/copilot.lua
Normal file
531
lua/codetyper/llm/copilot.lua
Normal file
@@ -0,0 +1,531 @@
|
||||
---Reference implementation:
|
||||
---https://github.com/zbirenbaum/copilot.lua/blob/master/lua/copilot/auth.lua config file
|
||||
---https://github.com/zed-industries/zed/blob/ad43bbbf5eda59eba65309735472e0be58b4f7dd/crates/copilot/src/copilot_chat.rs#L272 for authorization
|
||||
---
|
||||
---@class CopilotToken
|
||||
---@field annotations_enabled boolean
|
||||
---@field chat_enabled boolean
|
||||
---@field chat_jetbrains_enabled boolean
|
||||
---@field code_quote_enabled boolean
|
||||
---@field codesearch boolean
|
||||
---@field copilotignore_enabled boolean
|
||||
---@field endpoints {api: string, ["origin-tracker"]: string, proxy: string, telemetry: string}
|
||||
---@field expires_at integer
|
||||
---@field individual boolean
|
||||
---@field nes_enabled boolean
|
||||
---@field prompt_8k boolean
|
||||
---@field public_suggestions string
|
||||
---@field refresh_in integer
|
||||
---@field sku string
|
||||
---@field snippy_load_test_enabled boolean
|
||||
---@field telemetry string
|
||||
---@field token string
|
||||
---@field tracking_id string
|
||||
---@field vsc_electron_fetcher boolean
|
||||
---@field xcode boolean
|
||||
---@field xcode_chat boolean
|
||||
|
||||
local curl = require("plenary.curl")
|
||||
|
||||
local Path = require("plenary.path")
|
||||
local Utils = require("avante.utils")
|
||||
local Providers = require("avante.providers")
|
||||
local OpenAI = require("avante.providers").openai
|
||||
|
||||
local H = {}
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
|
||||
-- On-disk locations (under stdpath("data")/avante) for the cached Copilot
-- session token and the cross-instance refresh-timer lockfile.
local copilot_path = vim.fn.stdpath("data") .. "/avante/github-copilot.json"
local lockfile_path = vim.fn.stdpath("data") .. "/avante/copilot-timer.lock"
|
||||
|
||||
-- Lockfile management

--- Whether `pid` refers to a live process.
-- Signal 0 only probes for existence; uv returns 0 on success.
local function is_process_running(pid)
  return vim.uv.kill(pid, 0) == 0
end
|
||||
|
||||
--- Try to become the single instance that owns the token-refresh timer.
-- Writes our PID to a temp file, then renames it over the shared lockfile.
-- NOTE(review): the exists-check + rename sequence is not atomic; two
-- starting instances can race, but the loser simply fails to rename or
-- both refresh — benign. Confirm this is acceptable on Windows, where
-- os.rename over an existing file fails.
---@return boolean true when this process now owns the lock
local function try_acquire_timer_lock()
  local lockfile = Path:new(lockfile_path)

  local tmp_lockfile = lockfile_path .. ".tmp." .. vim.fn.getpid()

  Path:new(tmp_lockfile):write(tostring(vim.fn.getpid()), "w")

  -- Check existing lock: if the recorded PID is still alive, back off
  if lockfile:exists() then
    local content = lockfile:read()
    local pid = tonumber(content)
    if pid and is_process_running(pid) then
      os.remove(tmp_lockfile)
      return false -- Another instance is already managing
    end
  end

  -- Attempt to take ownership by renaming our temp file into place
  local success = os.rename(tmp_lockfile, lockfile_path)
  if not success then
    os.remove(tmp_lockfile)
    return false
  end

  return true
end
|
||||
|
||||
--- Poll every 30s: if this instance has no refresh timer and the lock has
-- become free (the previous manager exited), take over token refreshing.
local function start_manager_check_timer()
  -- Replace any previous poller so only one handle exists
  if M._manager_check_timer then
    M._manager_check_timer:stop()
    M._manager_check_timer:close()
  end

  M._manager_check_timer = vim.uv.new_timer()
  M._manager_check_timer:start(
    30000,
    30000,
    vim.schedule_wrap(function()
      if not M._refresh_timer and try_acquire_timer_lock() then
        M.setup_timer()
      end
    end)
  )
end
|
||||
|
||||
---@class OAuthToken
---@field user string
---@field oauth_token string
---
--- Locate the GitHub Copilot OAuth token written by copilot.lua/copilot.vim.
-- Raises (at the caller's level) when no config file is found.
---@return string
function H.get_oauth_token()
  local xdg_config = vim.fn.expand("$XDG_CONFIG_HOME")
  local os_name = Utils.get_os_name()
  ---@type string
  local config_dir

  -- Prefer $XDG_CONFIG_HOME, then ~/.config on linux/darwin, else the
  -- Windows AppData location
  if xdg_config and vim.fn.isdirectory(xdg_config) > 0 then
    config_dir = xdg_config
  elseif vim.tbl_contains({ "linux", "darwin" }, os_name) then
    config_dir = vim.fn.expand("~/.config")
  else
    config_dir = vim.fn.expand("~/AppData/Local")
  end

  --- hosts.json (copilot.lua), apps.json (copilot.vim)
  ---@type Path[]
  local paths = vim.iter({ "hosts.json", "apps.json" }):fold({}, function(acc, path)
    local yason = Path:new(config_dir):joinpath("github-copilot", path)
    if yason:exists() then
      table.insert(acc, yason)
    end
    return acc
  end)
  if #paths == 0 then
    error("You must setup copilot with either copilot.lua or copilot.vim", 2)
  end

  -- First existing file wins; pull the oauth_token from any key that
  -- mentions github.com (keys may be "github.com" or "github.com:user")
  local yason = paths[1]
  return vim
    .iter(
      ---@type table<string, OAuthToken>
      ---@diagnostic disable-next-line: param-type-mismatch
      vim.json.decode(yason:read())
    )
    :filter(function(k, _)
      return k:match("github.com")
    end)
    ---@param acc {oauth_token: string}
    :fold({}, function(acc, _, v)
      acc.oauth_token = v.oauth_token
      return acc
    end)
    .oauth_token
end
|
||||
|
||||
-- GitHub endpoint that exchanges the OAuth token for a short-lived session token.
H.chat_auth_url = "https://api.github.com/copilot_internal/v2/token"

--- Chat-completions URL for the given API base.
function H.chat_completion_url(base_url) return Utils.url_join(base_url, "/chat/completions") end

--- Response-API URL for the given API base.
function H.response_url(base_url) return Utils.url_join(base_url, "/responses") end
|
||||
|
||||
--- Exchange the stored OAuth token for a Copilot session token.
-- Skips the request when the cached token is still valid unless `force`.
---@param async boolean|nil Perform the HTTP request asynchronously (default true)
---@param force boolean|nil Refresh even when the cached token has not expired
---@return boolean|nil false when skipped, true on sync success, nil on async dispatch
function H.refresh_token(async, force)
  if not M.state then
    error("internal initialization error")
  end

  -- Explicit nil-check so a caller passing `false` is honored
  async = async == nil and true or async
  force = force or false

  -- Do not refresh token if not forced or not expired
  if
    not force
    and M.state.github_token
    and M.state.github_token.expires_at
    and M.state.github_token.expires_at > math.floor(os.time())
  then
    return false
  end

  local provider_conf = Providers.get_config("copilot")

  local curl_opts = {
    headers = {
      ["Authorization"] = "token " .. M.state.oauth_token,
      ["Accept"] = "application/json",
    },
    timeout = provider_conf.timeout,
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
  }

  local function handle_response(response)
    if response.status == 200 then
      M.state.github_token = vim.json.decode(response.body)
      -- Persist so other instances (via the file watcher) and restarts reuse it
      local file = Path:new(copilot_path)
      file:write(vim.json.encode(M.state.github_token), "w")
      if not vim.g.avante_login then
        vim.g.avante_login = true
      end

      -- If triggered synchronously, reset timer
      if not async and M._refresh_timer then
        M.setup_timer()
      end

      return true
    else
      -- NOTE(review): error() also fires inside the async curl callback
      -- where nothing catches it; the `return false` below is unreachable.
      -- Confirm whether async failures should be logged instead.
      error("Failed to get success response: " .. vim.inspect(response))
      return false
    end
  end

  if async then
    curl.get(
      H.chat_auth_url,
      vim.tbl_deep_extend("force", {
        callback = handle_response,
      }, curl_opts)
    )
  else
    local response = curl.get(H.chat_auth_url, curl_opts)
    handle_response(response)
  end
end
|
||||
|
||||
---@private
---@class AvanteCopilotState
---@field oauth_token string
---@field github_token CopilotToken?
-- Populated by M.setup(); nil until then.
M.state = nil

-- Copilot authenticates via the GitHub OAuth flow, not an API-key env var.
M.api_key_name = ""
-- Tokenizer used for client-side token counting.
M.tokenizer_id = "gpt-4o"
M.role_map = {
  user = "user",
  assistant = "assistant",
}

--- Streaming responses are always enabled for this provider.
function M:is_disable_stream()
  return false
end

-- Everything not overridden here falls through to the OpenAI provider.
setmetatable(M, { __index = OpenAI })
|
||||
|
||||
--- List chat-capable Copilot models; cached after the first successful call.
---@return table[] model descriptors (id, display_name, limits, policy, ...)
function M:list_models()
  if M._model_list_cache then
    return M._model_list_cache
  end
  if not M._is_setup then
    M.setup()
  end
  -- refresh token synchronously, only if it has expired
  -- (this should rarely happen, as we refresh the token in the background)
  H.refresh_token(false, false)
  local provider_conf = Providers.parse_config(self)
  local headers = self:build_headers()
  local curl_opts = {
    headers = Utils.tbl_override(headers, self.extra_headers),
    timeout = provider_conf.timeout,
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
  }

  local function handle_response(response)
    if response.status == 200 then
      local body = vim.json.decode(response.body)
      -- Keep only chat models, dropping pay-as-you-go variants
      -- ref: https://github.com/CopilotC-Nvim/CopilotChat.nvim/blob/16d897fd43d07e3b54478ccdb2f8a16e4df4f45a/lua/CopilotChat/config/providers.lua#L171-L187
      local models = vim.iter(body.data)
        :filter(function(model)
          return model.capabilities.type == "chat" and not vim.endswith(model.id, "paygo")
        end)
        :map(function(model)
          return {
            id = model.id,
            display_name = model.name,
            name = "copilot/" .. model.name .. " (" .. model.id .. ")",
            provider_name = "copilot",
            tokenizer = model.capabilities.tokenizer,
            max_input_tokens = model.capabilities.limits.max_prompt_tokens,
            max_output_tokens = model.capabilities.limits.max_output_tokens,
            -- enabled unless an explicit policy says otherwise
            policy = not model["policy"] or model["policy"]["state"] == "enabled",
            version = model.version,
          }
        end)
        :totable()
      M._model_list_cache = models
      return models
    else
      error("Failed to get success response: " .. vim.inspect(response))
      return {} -- unreachable: error() above raises
    end
  end

  local response = curl.get((M.state.github_token.endpoints.api or "") .. "/models", curl_opts)
  return handle_response(response)
end
|
||||
|
||||
--- Headers required by the Copilot chat endpoints.
-- Identifies this client as the VS Code Copilot Chat plugin and carries the
-- current session token from M.state.
function M:build_headers()
  local headers = {}
  headers["Authorization"] = "Bearer " .. M.state.github_token.token
  headers["User-Agent"] = "GitHubCopilotChat/0.26.7"
  headers["Editor-Version"] = "vscode/1.105.1"
  headers["Editor-Plugin-Version"] = "copilot-chat/0.26.7"
  headers["Copilot-Integration-Id"] = "vscode-chat"
  headers["Openai-Intent"] = "conversation-edits"
  return headers
end
|
||||
|
||||
--- Build the curl request (url, headers, body) for a Copilot chat call.
---@return table curl arguments consumed by the request layer
function M:parse_curl_args(prompt_opts)
  -- refresh token synchronously, only if it has expired
  -- (this should rarely happen, as we refresh the token in the background)
  H.refresh_token(false, false)

  local provider_conf, request_body = Providers.parse_config(self)
  local use_response_api = Providers.resolve_use_response_api(provider_conf, prompt_opts)
  local disable_tools = provider_conf.disable_tools or false

  -- Apply OpenAI's set_allowed_params for Response API compatibility
  OpenAI.set_allowed_params(provider_conf, request_body)

  local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true

  -- Tools are sent natively unless disabled or using the ReAct prompt style
  local tools = nil
  if not disable_tools and prompt_opts.tools and not use_ReAct_prompt then
    tools = {}
    for _, tool in ipairs(prompt_opts.tools) do
      local transformed_tool = OpenAI:transform_tool(tool)
      -- Response API uses flattened tool structure
      if use_response_api then
        if transformed_tool.type == "function" and transformed_tool["function"] then
          transformed_tool = {
            type = "function",
            name = transformed_tool["function"].name,
            description = transformed_tool["function"].description,
            parameters = transformed_tool["function"].parameters,
          }
        end
      end
      table.insert(tools, transformed_tool)
    end
  end

  local headers = self:build_headers()

  -- Tell Copilot whether the request originates from the user or the agent loop
  if prompt_opts.messages and #prompt_opts.messages > 0 then
    local last_message = prompt_opts.messages[#prompt_opts.messages]
    local initiator = last_message.role == "user" and "user" or "agent"
    headers["X-Initiator"] = initiator
  end

  local parsed_messages = self:parse_messages(prompt_opts)

  -- Build base body
  local base_body = {
    model = provider_conf.model,
    stream = true,
    tools = tools,
  }

  -- Response API uses 'input' instead of 'messages'
  -- NOTE: Copilot doesn't support previous_response_id, always send full history
  if use_response_api then
    base_body.input = parsed_messages

    -- Response API uses max_output_tokens instead of max_tokens/max_completion_tokens
    if request_body.max_completion_tokens then
      request_body.max_output_tokens = request_body.max_completion_tokens
      request_body.max_completion_tokens = nil
    end
    if request_body.max_tokens then
      request_body.max_output_tokens = request_body.max_tokens
      request_body.max_tokens = nil
    end
    -- Response API doesn't use stream_options
    base_body.stream_options = nil
    base_body.include = { "reasoning.encrypted_content" }
    base_body.reasoning = {
      summary = "detailed",
    }
    base_body.truncation = "disabled"
  else
    base_body.messages = parsed_messages
    base_body.stream_options = {
      include_usage = true,
    }
  end

  -- Prefer the endpoint advertised by the session token over the static config
  local base_url = M.state.github_token.endpoints.api or provider_conf.endpoint
  local build_url = use_response_api and H.response_url or H.chat_completion_url

  return {
    url = build_url(base_url),
    timeout = provider_conf.timeout,
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
    headers = Utils.tbl_override(headers, self.extra_headers),
    -- request_body overrides win over the computed base body
    body = vim.tbl_deep_extend("force", base_body, request_body),
  }
end
|
||||
|
||||
-- Handle of the background refresh timer; nil when this instance is not the manager.
M._refresh_timer = nil

--- (Re)start the token refresh timer.
-- First fire is scheduled 2 minutes before the current token expires;
-- subsequent refreshes run every 28 minutes.
function M.setup_timer()
  -- Replace any existing timer so only one handle is live
  if M._refresh_timer then
    M._refresh_timer:stop()
    M._refresh_timer:close()
  end

  -- Calculate time until token expires
  local now = math.floor(os.time())
  local expires_at = M.state.github_token and M.state.github_token.expires_at or now
  local time_until_expiry = math.max(0, expires_at - now)
  -- Refresh 2 minutes before expiration
  local initial_interval = math.max(0, (time_until_expiry - 120) * 1000)
  -- Regular interval of 28 minutes after the first refresh
  local repeat_interval = 28 * 60 * 1000

  M._refresh_timer = vim.uv.new_timer()
  M._refresh_timer:start(
    initial_interval,
    repeat_interval,
    vim.schedule_wrap(function()
      H.refresh_token(true, true)
    end)
  )
end
|
||||
|
||||
--- Watch the on-disk token cache so non-manager instances pick up tokens
-- refreshed by the managing instance. Idempotent: a second call is a no-op.
function M.setup_file_watcher()
  if M._file_watcher then
    return
  end

  local copilot_token_file = Path:new(copilot_path)
  M._file_watcher = vim.uv.new_fs_event()

  M._file_watcher:start(
    copilot_path,
    {},
    vim.schedule_wrap(function()
      -- Reload token from file; ignore unreadable/partial JSON writes
      if copilot_token_file:exists() then
        local ok, token = pcall(vim.json.decode, copilot_token_file:read())
        if ok then
          M.state.github_token = token
        end
      end
    end)
  )
end
|
||||
|
||||
-- Set to true once M.setup() has completed.
M._is_setup = false

--- True when a Copilot OAuth token can be located on disk.
-- get_oauth_token() raises when no copilot config exists; pcall converts
-- that into a boolean.
function M.is_env_set()
  return (pcall(function()
    H.get_oauth_token()
  end))
end
|
||||
|
||||
--- Initialize provider state: load the OAuth token, reuse any valid cached
-- session token, and start the timer/watcher machinery.
function M.setup()
  local copilot_token_file = Path:new(copilot_path)

  if not M.state then
    M.state = {
      github_token = nil,
      oauth_token = H.get_oauth_token(),
    }
  end

  -- Load and validate existing token (only reuse it if not yet expired)
  if copilot_token_file:exists() then
    local ok, token = pcall(vim.json.decode, copilot_token_file:read())
    if ok and token.expires_at and token.expires_at > math.floor(os.time()) then
      M.state.github_token = token
    end
  end

  -- Setup timer management: the lock owner refreshes in the background;
  -- everyone else does a one-shot async refresh if needed
  local timer_lock_acquired = try_acquire_timer_lock()
  if timer_lock_acquired then
    M.setup_timer()
  else
    vim.schedule(function()
      H.refresh_token(true, false)
    end)
  end

  M.setup_file_watcher()

  start_manager_check_timer()

  require("avante.tokenizers").setup(M.tokenizer_id)
  vim.g.avante_login = true
  M._is_setup = true
end
|
||||
|
||||
--- Release timers, the file watcher and the lockfile on shutdown.
-- Safe to call more than once; every handle is nil-ed after release.
function M.cleanup()
  -- Cleanup refresh timer
  if M._refresh_timer then
    M._refresh_timer:stop()
    M._refresh_timer:close()
    M._refresh_timer = nil
  end

  -- Remove lockfile if this process owns it.
  -- Fix: this was previously nested under the refresh-timer branch, so a
  -- process that owned the lock but had already lost its timer (or whose
  -- cleanup ran twice) could leave a stale lock behind. The PID check keeps
  -- us from ever deleting another instance's lock.
  local lockfile = Path:new(lockfile_path)
  if lockfile:exists() then
    local content = lockfile:read()
    local pid = tonumber(content)
    if pid and pid == vim.fn.getpid() then
      lockfile:rm()
    end
  end

  -- Cleanup manager check timer
  if M._manager_check_timer then
    M._manager_check_timer:stop()
    M._manager_check_timer:close()
    M._manager_check_timer = nil
  end

  -- Cleanup file watcher.
  -- Fix: close the uv handle as well as stopping it; stop() alone leaks
  -- the handle.
  if M._file_watcher then
    ---@diagnostic disable-next-line: param-type-mismatch
    M._file_watcher:stop()
    M._file_watcher:close()
    M._file_watcher = nil
  end
end
|
||||
|
||||
-- Register cleanup on Neovim exit so timers, watcher and lockfile are released
vim.api.nvim_create_autocmd("VimLeavePre", {
  callback = function()
    M.cleanup()
  end,
})
|
||||
|
||||
return M
|
||||
361
lua/codetyper/llm/gemini.lua
Normal file
361
lua/codetyper/llm/gemini.lua
Normal file
@@ -0,0 +1,361 @@
|
||||
local Utils = require("avante.utils")
|
||||
local Providers = require("avante.providers")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local OpenAI = require("avante.providers").openai
|
||||
local Prompts = require("avante.utils.prompts")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
|
||||
-- Environment variable holding the Gemini API key.
M.api_key_name = "GEMINI_API_KEY"
-- Gemini's API calls the assistant role "model".
M.role_map = {
  user = "user",
  assistant = "model",
}

--- Streaming responses are always enabled for this provider.
function M:is_disable_stream()
  return false
end
|
||||
|
||||
--- Convert an Avante tool spec into a Gemini functionDeclaration.
-- `parameters` is omitted entirely when the tool takes no inputs.
---@param tool AvanteLLMTool
function M:transform_to_function_declaration(tool)
  local properties, required = Utils.llm_tool_param_fields_to_json_schema(tool.param.fields)

  local schema = nil
  if not vim.tbl_isempty(properties) then
    schema = {
      type = "object",
      properties = properties,
      required = required,
    }
  end

  return {
    name = tool.name,
    -- Dynamic descriptions take precedence over the static one
    description = tool.get_description and tool.get_description() or tool.description,
    parameters = schema,
  }
end
|
||||
|
||||
--- Translate Avante's message history into Gemini's contents/systemInstruction
-- format, converting tool calls/results and inlining pasted images.
function M:parse_messages(opts)
  local provider_conf, _ = Providers.parse_config(self)
  local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true

  local contents = {}
  local prev_role = nil

  -- Maps tool_use id -> tool name, used as a fallback when a tool_result
  -- item arrives without its tool name
  local tool_id_to_name = {}
  vim.iter(opts.messages):each(function(message)
    local role = message.role
    -- Gemini requires strictly alternating roles; inject a filler turn
    -- whenever two consecutive messages share a role
    if role == prev_role then
      if role == M.role_map["user"] then
        table.insert(
          contents,
          { role = M.role_map["assistant"], parts = {
            { text = "Ok, I understand." },
          } }
        )
      else
        table.insert(contents, { role = M.role_map["user"], parts = {
          { text = "Ok" },
        } })
      end
    end
    prev_role = role
    local parts = {}
    local content_items = message.content
    if type(content_items) == "string" then
      table.insert(parts, { text = content_items })
    elseif type(content_items) == "table" then
      ---@cast content_items AvanteLLMMessageContentItem[]
      for _, item in ipairs(content_items) do
        if type(item) == "string" then
          table.insert(parts, { text = item })
        elseif type(item) == "table" and item.type == "text" then
          table.insert(parts, { text = item.text })
        elseif type(item) == "table" and item.type == "image" then
          -- NOTE(review): mime_type is hard-coded to image/png regardless of
          -- item.source — confirm other image types never reach here
          table.insert(parts, {
            inline_data = {
              mime_type = "image/png",
              data = item.source.data,
            },
          })
        elseif type(item) == "table" and item.type == "tool_use" and not use_ReAct_prompt then
          tool_id_to_name[item.id] = item.name
          -- Tool calls must be attributed to the model role
          role = "model"
          table.insert(parts, {
            functionCall = {
              name = item.name,
              args = item.input,
            },
          })
        elseif type(item) == "table" and item.type == "tool_result" and not use_ReAct_prompt then
          role = "function"
          -- Prefer structured JSON; fall back to the raw string content
          local ok, content = pcall(vim.json.decode, item.content)
          if not ok then
            content = item.content
          end
          -- item.name here refers to the name of the tool that was called,
          -- which is available in the tool_result content item prepared by llm.lua
          local tool_name = item.name
          if not tool_name then
            -- Fallback, though item.name should ideally always be present for tool_result
            tool_name = tool_id_to_name[item.tool_use_id]
          end
          table.insert(parts, {
            functionResponse = {
              name = tool_name,
              response = {
                name = tool_name, -- Gemini API requires the name in the response object as well
                content = content,
              },
            },
          })
        elseif type(item) == "table" and item.type == "thinking" then
          table.insert(parts, { text = item.thinking })
        elseif type(item) == "table" and item.type == "redacted_thinking" then
          table.insert(parts, { text = item.data })
        end
      end
      -- ReAct mode: tool results are rendered as plain text turns, with the
      -- originating tool_use echoed back as XML so the model has context
      if not provider_conf.disable_tools and use_ReAct_prompt then
        if content_items[1].type == "tool_result" then
          local tool_use_msg = nil
          -- Find the message that issued the matching tool_use
          for _, msg_ in ipairs(opts.messages) do
            if type(msg_.content) == "table" and #msg_.content > 0 then
              if
                msg_.content[1].type == "tool_use"
                and msg_.content[1].id == content_items[1].tool_use_id
              then
                tool_use_msg = msg_
                break
              end
            end
          end
          if tool_use_msg then
            table.insert(contents, {
              role = "model",
              parts = {
                { text = Utils.tool_use_to_xml(tool_use_msg.content[1]) },
              },
            })
            role = "user"
            table.insert(parts, {
              text = "The result of tool use "
                .. Utils.tool_use_to_xml(tool_use_msg.content[1])
                .. " is:\n",
            })
            table.insert(parts, {
              text = content_items[1].content,
            })
          end
        end
      end
    end
    if #parts > 0 then
      table.insert(contents, { role = M.role_map[role] or role, parts = parts })
    end
  end)

  -- Attach any pasted images to the final turn
  if Clipboard.support_paste_image() and opts.image_paths then
    for _, image_path in ipairs(opts.image_paths) do
      local image_data = {
        inline_data = {
          mime_type = "image/png",
          data = Clipboard.get_base64_content(image_path),
        },
      }

      table.insert(contents[#contents].parts, image_data)
    end
  end

  local system_prompt = opts.system_prompt

  if use_ReAct_prompt then
    system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts)
  end

  return {
    systemInstruction = {
      role = "user",
      parts = {
        {
          text = system_prompt,
        },
      },
    },
    contents = contents,
  }
end
|
||||
|
||||
--- Prepares the main request body for Gemini-like APIs.
---@param provider_instance AvanteProviderFunctor The provider instance (self).
---@param prompt_opts AvantePromptOptions Prompt options including messages, tools, system_prompt.
---@param provider_conf table Provider configuration from config.lua (e.g., model, top-level temperature/max_tokens).
---@param request_body_ table Request-specific overrides, typically from provider_conf.request_config_overrides.
---@return table The fully constructed request body.
function M.prepare_request_body(provider_instance, prompt_opts, provider_conf, request_body_)
  local body = {}
  body.generationConfig = request_body_.generationConfig or {}

  local react_mode = provider_conf.use_ReAct_prompt == true

  -- In ReAct mode, stop generation as soon as the model closes a tool-use block
  if react_mode then
    body.generationConfig.stopSequences = { "</tool_use>" }
  end

  local tools_disabled = provider_conf.disable_tools or false

  -- Native function declarations are only sent outside ReAct mode
  if not react_mode and not tools_disabled and prompt_opts.tools then
    local declarations = {}
    for _, tool in ipairs(prompt_opts.tools) do
      declarations[#declarations + 1] = provider_instance:transform_to_function_declaration(tool)
    end

    if #declarations > 0 then
      body.tools = {
        {
          functionDeclarations = declarations,
        },
      }
    end
  end

  -- Merge parsed messages with the config overrides (overrides win)
  return vim.tbl_deep_extend("force", {}, provider_instance:parse_messages(prompt_opts), body)
end
|
||||
|
||||
--- Map Gemini's usage metadata onto Avante's token-usage shape.
---@param usage avante.GeminiTokenUsage | nil
---@return avante.LLMTokenUsage | nil nil when no usage metadata was provided
function M.transform_gemini_usage(usage)
  if not usage then
    return nil
  end
  ---@type avante.LLMTokenUsage
  return {
    prompt_tokens = usage.promptTokenCount,
    completion_tokens = usage.candidatesTokenCount,
  }
end
|
||||
|
||||
--- Handle one SSE chunk from Gemini's streamGenerateContent endpoint:
-- forward text/tool parts, track usage, and map finishReason onto on_stop.
function M:parse_response(ctx, data_stream, _, opts)
  local ok, jsn = pcall(vim.json.decode, data_stream)
  if not ok then
    opts.on_stop({ reason = "error", error = "Failed to parse JSON response: " .. tostring(jsn) })
    return
  end

  if opts.update_tokens_usage and jsn.usageMetadata and jsn.usageMetadata ~= nil then
    local usage = M.transform_gemini_usage(jsn.usageMetadata)
    if usage ~= nil then
      opts.update_tokens_usage(usage)
    end
  end

  -- Handle prompt feedback first, as it might indicate an overall issue with the prompt
  if jsn.promptFeedback and jsn.promptFeedback.blockReason then
    local feedback = jsn.promptFeedback
    OpenAI:finish_pending_messages(ctx, opts) -- Ensure any pending messages are cleared
    opts.on_stop({
      reason = "error",
      error = "Prompt blocked or filtered. Reason: " .. feedback.blockReason,
      details = feedback,
    })
    return
  end

  if jsn.candidates and #jsn.candidates > 0 then
    -- Only the first candidate is consumed
    local candidate = jsn.candidates[1]
    ---@type AvanteLLMToolUse[]
    ctx.tool_use_list = ctx.tool_use_list or {}

    -- Check if candidate.content and candidate.content.parts exist before iterating
    if candidate.content and candidate.content.parts then
      for _, part in ipairs(candidate.content.parts) do
        if part.text then
          if opts.on_chunk then
            opts.on_chunk(part.text)
          end
          OpenAI:add_text_message(ctx, part.text, "generating", opts)
        elseif part.functionCall then
          -- Gemini does not supply call ids; synthesize one per turn
          if not ctx.function_call_id then
            ctx.function_call_id = 0
          end
          ctx.function_call_id = ctx.function_call_id + 1
          local tool_use = {
            id = ctx.turn_id .. "-" .. tostring(ctx.function_call_id),
            name = part.functionCall.name,
            input_json = vim.json.encode(part.functionCall.args),
          }
          table.insert(ctx.tool_use_list, tool_use)
          OpenAI:add_tool_use_message(ctx, tool_use, "generated", opts)
        end
      end
    end

    -- Check for finishReason to determine if this candidate's stream is done.
    if candidate.finishReason then
      OpenAI:finish_pending_messages(ctx, opts)
      local reason_str = candidate.finishReason
      local stop_details = { finish_reason = reason_str }
      stop_details.usage = M.transform_gemini_usage(jsn.usageMetadata)

      if reason_str == "TOOL_CODE" then
        -- Model indicates a tool-related stop.
        -- The tool_use list is added to the table in llm.lua
        opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
      elseif reason_str == "STOP" then
        if ctx.tool_use_list and #ctx.tool_use_list > 0 then
          -- Natural stop, but tools were found in this final chunk.
          opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
        else
          -- Natural stop, no tools in this final chunk.
          -- llm.lua will check its accumulated tools if tool_choice was active.
          opts.on_stop(vim.tbl_deep_extend("force", { reason = "complete" }, stop_details))
        end
      elseif reason_str == "MAX_TOKENS" then
        opts.on_stop(vim.tbl_deep_extend("force", { reason = "max_tokens" }, stop_details))
      elseif reason_str == "SAFETY" or reason_str == "RECITATION" then
        opts.on_stop(
          vim.tbl_deep_extend(
            "force",
            { reason = "error", error = "Generation stopped: " .. reason_str },
            stop_details
          )
        )
      else -- OTHER, FINISH_REASON_UNSPECIFIED, or any other unhandled reason.
        opts.on_stop(
          vim.tbl_deep_extend(
            "force",
            { reason = "error", error = "Generation stopped with unhandled reason: " .. reason_str },
            stop_details
          )
        )
      end
    end
    -- If no finishReason, it's an intermediate chunk; do not call on_stop.
  end
end
|
||||
|
||||
--- Build the curl request for Gemini's streaming endpoint.
---@param prompt_opts AvantePromptOptions
---@return AvanteCurlOutput|nil nil when the API key is missing
function M:parse_curl_args(prompt_opts)
  local provider_conf, request_body = Providers.parse_config(self)

  local api_key = self:parse_api_key()
  if api_key == nil then
    Utils.error("Gemini: API key is not set. Please set " .. M.api_key_name)
    return nil
  end

  -- Gemini passes the key as a query parameter on the model endpoint
  local url = Utils.url_join(
    provider_conf.endpoint,
    provider_conf.model .. ":streamGenerateContent?alt=sse&key=" .. api_key
  )

  return {
    url = url,
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
    headers = Utils.tbl_override({ ["Content-Type"] = "application/json" }, self.extra_headers),
    body = M.prepare_request_body(self, prompt_opts, provider_conf, request_body),
  }
end
|
||||
|
||||
return M
|
||||
@@ -8,19 +8,19 @@ local llm = require("codetyper.llm")
|
||||
--- Get Ollama host from config
|
||||
---@return string Host URL
|
||||
local function get_host()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
return config.llm.ollama.host
|
||||
return config.llm.ollama.host
|
||||
end
|
||||
|
||||
--- Get model from config
|
||||
---@return string Model name
|
||||
local function get_model()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
local codetyper = require("codetyper")
|
||||
local config = codetyper.get_config()
|
||||
|
||||
return config.llm.ollama.model
|
||||
return config.llm.ollama.model
|
||||
end
|
||||
|
||||
--- Build request body for Ollama API
|
||||
@@ -28,93 +28,96 @@ end
|
||||
---@param context table Context information
|
||||
---@return table Request body
|
||||
local function build_request_body(prompt, context)
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
local system_prompt = llm.build_system_prompt(context)
|
||||
|
||||
return {
|
||||
model = get_model(),
|
||||
system = system_prompt,
|
||||
prompt = prompt,
|
||||
stream = false,
|
||||
options = {
|
||||
temperature = 0.2,
|
||||
num_predict = 4096,
|
||||
},
|
||||
}
|
||||
return {
|
||||
model = get_model(),
|
||||
system = system_prompt,
|
||||
prompt = prompt,
|
||||
stream = false,
|
||||
options = {
|
||||
temperature = 0.2,
|
||||
num_predict = 4096,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
--- Make HTTP request to Ollama API
|
||||
---@param body table Request body
|
||||
---@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function
|
||||
local function make_request(body, callback)
|
||||
local host = get_host()
|
||||
local url = host .. "/api/generate"
|
||||
local json_body = vim.json.encode(body)
|
||||
local host = get_host()
|
||||
local url = host .. "/api/generate"
|
||||
local json_body = vim.json.encode(body)
|
||||
|
||||
local cmd = {
|
||||
"curl",
|
||||
"-s",
|
||||
"-X", "POST",
|
||||
url,
|
||||
"-H", "Content-Type: application/json",
|
||||
"-d", json_body,
|
||||
}
|
||||
local cmd = {
|
||||
"curl",
|
||||
"-s",
|
||||
"-X",
|
||||
"POST",
|
||||
url,
|
||||
"-H",
|
||||
"Content-Type: application/json",
|
||||
"-d",
|
||||
json_body,
|
||||
}
|
||||
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if not data or #data == 0 or (data[1] == "" and #data == 1) then
|
||||
return
|
||||
end
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if not data or #data == 0 or (data[1] == "" and #data == 1) then
|
||||
return
|
||||
end
|
||||
|
||||
local response_text = table.concat(data, "\n")
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
local response_text = table.concat(data, "\n")
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Ollama response", nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
if not ok then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Ollama response", nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error or "Ollama API error", nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
if response.error then
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error or "Ollama API error", nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Extract usage info
|
||||
local usage = {
|
||||
prompt_tokens = response.prompt_eval_count or 0,
|
||||
response_tokens = response.eval_count or 0,
|
||||
}
|
||||
-- Extract usage info
|
||||
local usage = {
|
||||
prompt_tokens = response.prompt_eval_count or 0,
|
||||
response_tokens = response.eval_count or 0,
|
||||
}
|
||||
|
||||
if response.response then
|
||||
local code = llm.extract_code(response.response)
|
||||
vim.schedule(function()
|
||||
callback(code, nil, usage)
|
||||
end)
|
||||
else
|
||||
vim.schedule(function()
|
||||
callback(nil, "No response from Ollama", nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed with code: " .. code, nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
})
|
||||
if response.response then
|
||||
local code = llm.extract_code(response.response)
|
||||
vim.schedule(function()
|
||||
callback(code, nil, usage)
|
||||
end)
|
||||
else
|
||||
vim.schedule(function()
|
||||
callback(nil, "No response from Ollama", nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed with code: " .. code, nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Generate code using Ollama API
|
||||
@@ -122,111 +125,107 @@ end
|
||||
---@param context table Context information
|
||||
---@param callback fun(response: string|nil, error: string|nil) Callback function
|
||||
function M.generate(prompt, context, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local model = get_model()
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local model = get_model()
|
||||
|
||||
-- Log the request
|
||||
logs.request("ollama", model)
|
||||
logs.thinking("Building request body...")
|
||||
-- Log the request
|
||||
logs.request("ollama", model)
|
||||
logs.thinking("Building request body...")
|
||||
|
||||
local body = build_request_body(prompt, context)
|
||||
local body = build_request_body(prompt, context)
|
||||
|
||||
-- Estimate prompt tokens
|
||||
local prompt_estimate = logs.estimate_tokens(vim.json.encode(body))
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
logs.thinking("Sending to Ollama API...")
|
||||
-- Estimate prompt tokens
|
||||
local prompt_estimate = logs.estimate_tokens(vim.json.encode(body))
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
logs.thinking("Sending to Ollama API...")
|
||||
|
||||
utils.notify("Sending request to Ollama...", vim.log.levels.INFO)
|
||||
utils.notify("Sending request to Ollama...", vim.log.levels.INFO)
|
||||
|
||||
make_request(body, function(response, err, usage)
|
||||
if err then
|
||||
logs.error(err)
|
||||
utils.notify(err, vim.log.levels.ERROR)
|
||||
callback(nil, err)
|
||||
else
|
||||
-- Log token usage
|
||||
if usage then
|
||||
logs.response(
|
||||
usage.prompt_tokens or 0,
|
||||
usage.response_tokens or 0,
|
||||
"end_turn"
|
||||
)
|
||||
end
|
||||
logs.thinking("Response received, extracting code...")
|
||||
logs.info("Code generated successfully")
|
||||
utils.notify("Code generated successfully", vim.log.levels.INFO)
|
||||
callback(response, nil)
|
||||
end
|
||||
end)
|
||||
make_request(body, function(response, err, usage)
|
||||
if err then
|
||||
logs.error(err)
|
||||
utils.notify(err, vim.log.levels.ERROR)
|
||||
callback(nil, err)
|
||||
else
|
||||
-- Log token usage
|
||||
if usage then
|
||||
logs.response(usage.prompt_tokens or 0, usage.response_tokens or 0, "end_turn")
|
||||
end
|
||||
logs.thinking("Response received, extracting code...")
|
||||
logs.info("Code generated successfully")
|
||||
utils.notify("Code generated successfully", vim.log.levels.INFO)
|
||||
callback(response, nil)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
--- Check if Ollama is reachable
|
||||
---@param callback fun(ok: boolean, error: string|nil) Callback function
|
||||
function M.health_check(callback)
|
||||
local host = get_host()
|
||||
local host = get_host()
|
||||
|
||||
local cmd = { "curl", "-s", host .. "/api/tags" }
|
||||
local cmd = { "curl", "-s", host .. "/api/tags" }
|
||||
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
callback(true, nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
vim.schedule(function()
|
||||
callback(false, "Cannot connect to Ollama at " .. host)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
})
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
callback(true, nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
vim.schedule(function()
|
||||
callback(false, "Cannot connect to Ollama at " .. host)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Check if Ollama is properly configured
|
||||
---@return boolean, string? Valid status and optional error message
|
||||
function M.validate()
|
||||
local host = get_host()
|
||||
if not host or host == "" then
|
||||
return false, "Ollama host not configured"
|
||||
end
|
||||
local model = get_model()
|
||||
if not model or model == "" then
|
||||
return false, "Ollama model not configured"
|
||||
end
|
||||
return true
|
||||
local host = get_host()
|
||||
if not host or host == "" then
|
||||
return false, "Ollama host not configured"
|
||||
end
|
||||
local model = get_model()
|
||||
if not model or model == "" then
|
||||
return false, "Ollama model not configured"
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
--- Build system prompt for agent mode with tool instructions
|
||||
---@param context table Context information
|
||||
---@return string System prompt
|
||||
local function build_agent_system_prompt(context)
|
||||
local agent_prompts = require("codetyper.prompts.agent")
|
||||
local tools_module = require("codetyper.agent.tools")
|
||||
local agent_prompts = require("codetyper.prompts.agent")
|
||||
local tools_module = require("codetyper.agent.tools")
|
||||
|
||||
local system_prompt = agent_prompts.system .. "\n\n"
|
||||
system_prompt = system_prompt .. tools_module.to_prompt_format() .. "\n\n"
|
||||
system_prompt = system_prompt .. agent_prompts.tool_instructions
|
||||
local system_prompt = agent_prompts.system .. "\n\n"
|
||||
system_prompt = system_prompt .. tools_module.to_prompt_format() .. "\n\n"
|
||||
system_prompt = system_prompt .. agent_prompts.tool_instructions
|
||||
|
||||
-- Add context about current file if available
|
||||
if context.file_path then
|
||||
system_prompt = system_prompt .. "\n\nCurrent working context:\n"
|
||||
system_prompt = system_prompt .. "- File: " .. context.file_path .. "\n"
|
||||
if context.language then
|
||||
system_prompt = system_prompt .. "- Language: " .. context.language .. "\n"
|
||||
end
|
||||
end
|
||||
-- Add context about current file if available
|
||||
if context.file_path then
|
||||
system_prompt = system_prompt .. "\n\nCurrent working context:\n"
|
||||
system_prompt = system_prompt .. "- File: " .. context.file_path .. "\n"
|
||||
if context.language then
|
||||
system_prompt = system_prompt .. "- Language: " .. context.language .. "\n"
|
||||
end
|
||||
end
|
||||
|
||||
-- Add project root info
|
||||
local root = utils.get_project_root()
|
||||
if root then
|
||||
system_prompt = system_prompt .. "- Project root: " .. root .. "\n"
|
||||
end
|
||||
-- Add project root info
|
||||
local root = utils.get_project_root()
|
||||
if root then
|
||||
system_prompt = system_prompt .. "- Project root: " .. root .. "\n"
|
||||
end
|
||||
|
||||
return system_prompt
|
||||
return system_prompt
|
||||
end
|
||||
|
||||
--- Build request body for Ollama API with tools (chat format)
|
||||
@@ -234,114 +233,117 @@ end
|
||||
---@param context table Context information
|
||||
---@return table Request body
|
||||
local function build_tools_request_body(messages, context)
|
||||
local system_prompt = build_agent_system_prompt(context)
|
||||
local system_prompt = build_agent_system_prompt(context)
|
||||
|
||||
-- Convert messages to Ollama chat format
|
||||
local ollama_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
local content = msg.content
|
||||
-- Handle complex content (like tool results)
|
||||
if type(content) == "table" then
|
||||
local text_parts = {}
|
||||
for _, part in ipairs(content) do
|
||||
if part.type == "tool_result" then
|
||||
table.insert(text_parts, "[" .. (part.name or "tool") .. " result]: " .. (part.content or ""))
|
||||
elseif part.type == "text" then
|
||||
table.insert(text_parts, part.text or "")
|
||||
end
|
||||
end
|
||||
content = table.concat(text_parts, "\n")
|
||||
end
|
||||
-- Convert messages to Ollama chat format
|
||||
local ollama_messages = {}
|
||||
for _, msg in ipairs(messages) do
|
||||
local content = msg.content
|
||||
-- Handle complex content (like tool results)
|
||||
if type(content) == "table" then
|
||||
local text_parts = {}
|
||||
for _, part in ipairs(content) do
|
||||
if part.type == "tool_result" then
|
||||
table.insert(text_parts, "[" .. (part.name or "tool") .. " result]: " .. (part.content or ""))
|
||||
elseif part.type == "text" then
|
||||
table.insert(text_parts, part.text or "")
|
||||
end
|
||||
end
|
||||
content = table.concat(text_parts, "\n")
|
||||
end
|
||||
|
||||
table.insert(ollama_messages, {
|
||||
role = msg.role,
|
||||
content = content,
|
||||
})
|
||||
end
|
||||
table.insert(ollama_messages, {
|
||||
role = msg.role,
|
||||
content = content,
|
||||
})
|
||||
end
|
||||
|
||||
return {
|
||||
model = get_model(),
|
||||
messages = ollama_messages,
|
||||
system = system_prompt,
|
||||
stream = false,
|
||||
options = {
|
||||
temperature = 0.3,
|
||||
num_predict = 4096,
|
||||
},
|
||||
}
|
||||
return {
|
||||
model = get_model(),
|
||||
messages = ollama_messages,
|
||||
system = system_prompt,
|
||||
stream = false,
|
||||
options = {
|
||||
temperature = 0.3,
|
||||
num_predict = 4096,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
--- Make HTTP request to Ollama chat API
|
||||
---@param body table Request body
|
||||
---@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function
|
||||
local function make_chat_request(body, callback)
|
||||
local host = get_host()
|
||||
local url = host .. "/api/chat"
|
||||
local json_body = vim.json.encode(body)
|
||||
local host = get_host()
|
||||
local url = host .. "/api/chat"
|
||||
local json_body = vim.json.encode(body)
|
||||
|
||||
local cmd = {
|
||||
"curl",
|
||||
"-s",
|
||||
"-X", "POST",
|
||||
url,
|
||||
"-H", "Content-Type: application/json",
|
||||
"-d", json_body,
|
||||
}
|
||||
local cmd = {
|
||||
"curl",
|
||||
"-s",
|
||||
"-X",
|
||||
"POST",
|
||||
url,
|
||||
"-H",
|
||||
"Content-Type: application/json",
|
||||
"-d",
|
||||
json_body,
|
||||
}
|
||||
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if not data or #data == 0 or (data[1] == "" and #data == 1) then
|
||||
return
|
||||
end
|
||||
vim.fn.jobstart(cmd, {
|
||||
stdout_buffered = true,
|
||||
on_stdout = function(_, data)
|
||||
if not data or #data == 0 or (data[1] == "" and #data == 1) then
|
||||
return
|
||||
end
|
||||
|
||||
local response_text = table.concat(data, "\n")
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
local response_text = table.concat(data, "\n")
|
||||
local ok, response = pcall(vim.json.decode, response_text)
|
||||
|
||||
if not ok then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Ollama response", nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
if not ok then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Failed to parse Ollama response", nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
if response.error then
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error or "Ollama API error", nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
if response.error then
|
||||
vim.schedule(function()
|
||||
callback(nil, response.error or "Ollama API error", nil)
|
||||
end)
|
||||
return
|
||||
end
|
||||
|
||||
-- Extract usage info
|
||||
local usage = {
|
||||
prompt_tokens = response.prompt_eval_count or 0,
|
||||
response_tokens = response.eval_count or 0,
|
||||
}
|
||||
-- Extract usage info
|
||||
local usage = {
|
||||
prompt_tokens = response.prompt_eval_count or 0,
|
||||
response_tokens = response.eval_count or 0,
|
||||
}
|
||||
|
||||
-- Return the message content for agent parsing
|
||||
if response.message and response.message.content then
|
||||
vim.schedule(function()
|
||||
callback(response.message.content, nil, usage)
|
||||
end)
|
||||
else
|
||||
vim.schedule(function()
|
||||
callback(nil, "No response from Ollama", nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
-- Don't double-report errors
|
||||
end
|
||||
end,
|
||||
})
|
||||
-- Return the message content for agent parsing
|
||||
if response.message and response.message.content then
|
||||
vim.schedule(function()
|
||||
callback(response.message.content, nil, usage)
|
||||
end)
|
||||
else
|
||||
vim.schedule(function()
|
||||
callback(nil, "No response from Ollama", nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_stderr = function(_, data)
|
||||
if data and #data > 0 and data[1] ~= "" then
|
||||
vim.schedule(function()
|
||||
callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil)
|
||||
end)
|
||||
end
|
||||
end,
|
||||
on_exit = function(_, code)
|
||||
if code ~= 0 then
|
||||
-- Don't double-report errors
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
--- Generate response with tools using Ollama API
|
||||
@@ -350,50 +352,46 @@ end
|
||||
---@param tools table Tool definitions (embedded in prompt for Ollama)
|
||||
---@param callback fun(response: string|nil, error: string|nil) Callback function
|
||||
function M.generate_with_tools(messages, context, tools, callback)
|
||||
local logs = require("codetyper.agent.logs")
|
||||
local logs = require("codetyper.agent.logs")
|
||||
|
||||
-- Log the request
|
||||
local model = get_model()
|
||||
logs.request("ollama", model)
|
||||
logs.thinking("Preparing API request...")
|
||||
-- Log the request
|
||||
local model = get_model()
|
||||
logs.request("ollama", model)
|
||||
logs.thinking("Preparing API request...")
|
||||
|
||||
local body = build_tools_request_body(messages, context)
|
||||
local body = build_tools_request_body(messages, context)
|
||||
|
||||
-- Estimate prompt tokens
|
||||
local prompt_estimate = logs.estimate_tokens(vim.json.encode(body))
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
-- Estimate prompt tokens
|
||||
local prompt_estimate = logs.estimate_tokens(vim.json.encode(body))
|
||||
logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate))
|
||||
|
||||
make_chat_request(body, function(response, err, usage)
|
||||
if err then
|
||||
logs.error(err)
|
||||
callback(nil, err)
|
||||
else
|
||||
-- Log token usage
|
||||
if usage then
|
||||
logs.response(
|
||||
usage.prompt_tokens or 0,
|
||||
usage.response_tokens or 0,
|
||||
"end_turn"
|
||||
)
|
||||
end
|
||||
make_chat_request(body, function(response, err, usage)
|
||||
if err then
|
||||
logs.error(err)
|
||||
callback(nil, err)
|
||||
else
|
||||
-- Log token usage
|
||||
if usage then
|
||||
logs.response(usage.prompt_tokens or 0, usage.response_tokens or 0, "end_turn")
|
||||
end
|
||||
|
||||
-- Log if response contains tool calls
|
||||
if response then
|
||||
local parser = require("codetyper.agent.parser")
|
||||
local parsed = parser.parse_ollama_response(response)
|
||||
if #parsed.tool_calls > 0 then
|
||||
for _, tc in ipairs(parsed.tool_calls) do
|
||||
logs.thinking("Tool call: " .. tc.name)
|
||||
end
|
||||
end
|
||||
if parsed.text and parsed.text ~= "" then
|
||||
logs.thinking("Response contains text")
|
||||
end
|
||||
end
|
||||
-- Log if response contains tool calls
|
||||
if response then
|
||||
local parser = require("codetyper.agent.parser")
|
||||
local parsed = parser.parse_ollama_response(response)
|
||||
if #parsed.tool_calls > 0 then
|
||||
for _, tc in ipairs(parsed.tool_calls) do
|
||||
logs.thinking("Tool call: " .. tc.name)
|
||||
end
|
||||
end
|
||||
if parsed.text and parsed.text ~= "" then
|
||||
logs.thinking("Response contains text")
|
||||
end
|
||||
end
|
||||
|
||||
callback(response, nil)
|
||||
end
|
||||
end)
|
||||
callback(response, nil)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
973
lua/codetyper/llm/openai.lua
Normal file
973
lua/codetyper/llm/openai.lua
Normal file
@@ -0,0 +1,973 @@
|
||||
local Utils = require("avante.utils")
|
||||
local Config = require("avante.config")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local Providers = require("avante.providers")
|
||||
local HistoryMessage = require("avante.history.message")
|
||||
local ReActParser = require("avante.libs.ReAct_parser2")
|
||||
local JsonParser = require("avante.libs.jsonparser")
|
||||
local Prompts = require("avante.utils.prompts")
|
||||
local LlmTools = require("avante.llm_tools")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
|
||||
M.api_key_name = "OPENAI_API_KEY"
|
||||
|
||||
M.role_map = {
|
||||
user = "user",
|
||||
assistant = "assistant",
|
||||
}
|
||||
|
||||
function M:is_disable_stream()
|
||||
return false
|
||||
end
|
||||
|
||||
---@param tool AvanteLLMTool
|
||||
---@return AvanteOpenAITool
|
||||
function M:transform_tool(tool)
|
||||
local input_schema_properties, required = Utils.llm_tool_param_fields_to_json_schema(tool.param.fields)
|
||||
---@type AvanteOpenAIToolFunctionParameters
|
||||
local parameters = {
|
||||
type = "object",
|
||||
properties = input_schema_properties,
|
||||
required = required,
|
||||
additionalProperties = false,
|
||||
}
|
||||
---@type AvanteOpenAITool
|
||||
local res = {
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool.name,
|
||||
description = tool.get_description and tool.get_description() or tool.description,
|
||||
parameters = parameters,
|
||||
},
|
||||
}
|
||||
return res
|
||||
end
|
||||
|
||||
function M.is_openrouter(url)
|
||||
return url:match("^https://openrouter%.ai/")
|
||||
end
|
||||
|
||||
function M.is_mistral(url)
|
||||
return url:match("^https://api%.mistral%.ai/")
|
||||
end
|
||||
|
||||
---@param opts AvantePromptOptions
|
||||
function M.get_user_message(opts)
|
||||
vim.deprecate("get_user_message", "parse_messages", "0.1.0", "avante.nvim")
|
||||
return table.concat(
|
||||
vim.iter(opts.messages)
|
||||
:filter(function(_, value)
|
||||
return value == nil or value.role ~= "user"
|
||||
end)
|
||||
:fold({}, function(acc, value)
|
||||
acc = vim.list_extend({}, acc)
|
||||
acc = vim.list_extend(acc, { value.content })
|
||||
return acc
|
||||
end),
|
||||
"\n"
|
||||
)
|
||||
end
|
||||
|
||||
function M.is_reasoning_model(model)
|
||||
return model
|
||||
and (string.match(model, "^o%d+") ~= nil or (string.match(model, "gpt%-5") ~= nil and model ~= "gpt-5-chat"))
|
||||
end
|
||||
|
||||
function M.set_allowed_params(provider_conf, request_body)
|
||||
local use_response_api = Providers.resolve_use_response_api(provider_conf, nil)
|
||||
if M.is_reasoning_model(provider_conf.model) then
|
||||
-- Reasoning models have specific parameter requirements
|
||||
request_body.temperature = 1
|
||||
-- Response API doesn't support temperature for reasoning models
|
||||
if use_response_api then
|
||||
request_body.temperature = nil
|
||||
end
|
||||
else
|
||||
request_body.reasoning_effort = nil
|
||||
request_body.reasoning = nil
|
||||
end
|
||||
-- If max_tokens is set in config, unset max_completion_tokens
|
||||
if request_body.max_tokens then
|
||||
request_body.max_completion_tokens = nil
|
||||
end
|
||||
|
||||
-- Handle Response API specific parameters
|
||||
if use_response_api then
|
||||
-- Convert reasoning_effort to reasoning object for Response API
|
||||
if request_body.reasoning_effort then
|
||||
request_body.reasoning = {
|
||||
effort = request_body.reasoning_effort,
|
||||
}
|
||||
request_body.reasoning_effort = nil
|
||||
end
|
||||
|
||||
-- Response API doesn't support some parameters
|
||||
-- Remove unsupported parameters for Response API
|
||||
local unsupported_params = {
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"logit_bias",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"n",
|
||||
}
|
||||
for _, param in ipairs(unsupported_params) do
|
||||
request_body[param] = nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function M:parse_messages(opts)
|
||||
local messages = {}
|
||||
local provider_conf, _ = Providers.parse_config(self)
|
||||
local use_response_api = Providers.resolve_use_response_api(provider_conf, opts)
|
||||
|
||||
local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true
|
||||
local system_prompt = opts.system_prompt
|
||||
|
||||
if use_ReAct_prompt then
|
||||
system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts)
|
||||
end
|
||||
|
||||
if self.is_reasoning_model(provider_conf.model) then
|
||||
table.insert(messages, { role = "developer", content = system_prompt })
|
||||
else
|
||||
table.insert(messages, { role = "system", content = system_prompt })
|
||||
end
|
||||
|
||||
local has_tool_use = false
|
||||
|
||||
vim.iter(opts.messages):each(function(msg)
|
||||
if type(msg.content) == "string" then
|
||||
table.insert(messages, { role = self.role_map[msg.role], content = msg.content })
|
||||
elseif type(msg.content) == "table" then
|
||||
-- Check if this is a reasoning message (object with type "reasoning")
|
||||
if msg.content.type == "reasoning" then
|
||||
-- Add reasoning message directly (for Response API)
|
||||
table.insert(messages, {
|
||||
type = "reasoning",
|
||||
id = msg.content.id,
|
||||
encrypted_content = msg.content.encrypted_content,
|
||||
summary = msg.content.summary,
|
||||
})
|
||||
return
|
||||
end
|
||||
|
||||
local content = {}
|
||||
local tool_calls = {}
|
||||
local tool_results = {}
|
||||
for _, item in ipairs(msg.content) do
|
||||
if type(item) == "string" then
|
||||
table.insert(content, { type = "text", text = item })
|
||||
elseif item.type == "text" then
|
||||
table.insert(content, { type = "text", text = item.text })
|
||||
elseif item.type == "image" then
|
||||
table.insert(content, {
|
||||
type = "image_url",
|
||||
image_url = {
|
||||
url = "data:"
|
||||
.. item.source.media_type
|
||||
.. ";"
|
||||
.. item.source.type
|
||||
.. ","
|
||||
.. item.source.data,
|
||||
},
|
||||
})
|
||||
elseif item.type == "reasoning" then
|
||||
-- Add reasoning message directly (for Response API)
|
||||
table.insert(messages, {
|
||||
type = "reasoning",
|
||||
id = item.id,
|
||||
encrypted_content = item.encrypted_content,
|
||||
summary = item.summary,
|
||||
})
|
||||
elseif item.type == "tool_use" and not use_ReAct_prompt then
|
||||
has_tool_use = true
|
||||
table.insert(tool_calls, {
|
||||
id = item.id,
|
||||
type = "function",
|
||||
["function"] = { name = item.name, arguments = vim.json.encode(item.input) },
|
||||
})
|
||||
elseif item.type == "tool_result" and has_tool_use and not use_ReAct_prompt then
|
||||
table.insert(
|
||||
tool_results,
|
||||
{
|
||||
tool_call_id = item.tool_use_id,
|
||||
content = item.is_error and "Error: " .. item.content or item.content,
|
||||
}
|
||||
)
|
||||
end
|
||||
end
|
||||
if not provider_conf.disable_tools and use_ReAct_prompt then
|
||||
if msg.content[1].type == "tool_result" then
|
||||
local tool_use_msg = nil
|
||||
for _, msg_ in ipairs(opts.messages) do
|
||||
if type(msg_.content) == "table" and #msg_.content > 0 then
|
||||
if
|
||||
msg_.content[1].type == "tool_use"
|
||||
and msg_.content[1].id == msg.content[1].tool_use_id
|
||||
then
|
||||
tool_use_msg = msg_
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
if tool_use_msg then
|
||||
msg.role = "user"
|
||||
table.insert(content, {
|
||||
type = "text",
|
||||
text = "The result of tool use "
|
||||
.. Utils.tool_use_to_xml(tool_use_msg.content[1])
|
||||
.. " is:\n",
|
||||
})
|
||||
table.insert(content, {
|
||||
type = "text",
|
||||
text = msg.content[1].content,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
if #content > 0 then
|
||||
table.insert(messages, { role = self.role_map[msg.role], content = content })
|
||||
end
|
||||
if not provider_conf.disable_tools and not use_ReAct_prompt then
|
||||
if #tool_calls > 0 then
|
||||
-- Only skip tool_calls if using Response API with previous_response_id support
|
||||
-- Copilot uses Response API format but doesn't support previous_response_id
|
||||
local should_include_tool_calls = not use_response_api
|
||||
or not provider_conf.support_previous_response_id
|
||||
|
||||
if should_include_tool_calls then
|
||||
-- For Response API without previous_response_id support (like Copilot),
|
||||
-- convert tool_calls to function_call items in input
|
||||
if use_response_api then
|
||||
for _, tool_call in ipairs(tool_calls) do
|
||||
table.insert(messages, {
|
||||
type = "function_call",
|
||||
call_id = tool_call.id,
|
||||
name = tool_call["function"].name,
|
||||
arguments = tool_call["function"].arguments,
|
||||
})
|
||||
end
|
||||
else
|
||||
-- Chat Completions API format
|
||||
local last_message = messages[#messages]
|
||||
if
|
||||
last_message
|
||||
and last_message.role == self.role_map["assistant"]
|
||||
and last_message.tool_calls
|
||||
then
|
||||
last_message.tool_calls = vim.list_extend(last_message.tool_calls, tool_calls)
|
||||
|
||||
if not last_message.content then
|
||||
last_message.content = ""
|
||||
end
|
||||
else
|
||||
table.insert(
|
||||
messages,
|
||||
{ role = self.role_map["assistant"], tool_calls = tool_calls, content = "" }
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
-- If support_previous_response_id is true, Response API manages function call history
|
||||
-- So we can skip adding tool_calls to input messages
|
||||
end
|
||||
if #tool_results > 0 then
|
||||
for _, tool_result in ipairs(tool_results) do
|
||||
-- Response API uses different format for function outputs
|
||||
if use_response_api then
|
||||
table.insert(messages, {
|
||||
type = "function_call_output",
|
||||
call_id = tool_result.tool_call_id,
|
||||
output = tool_result.content or "",
|
||||
})
|
||||
else
|
||||
table.insert(
|
||||
messages,
|
||||
{
|
||||
role = "tool",
|
||||
tool_call_id = tool_result.tool_call_id,
|
||||
content = tool_result.content or "",
|
||||
}
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end)
|
||||
|
||||
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
|
||||
local message_content = messages[#messages].content
|
||||
if type(message_content) ~= "table" or message_content[1] == nil then
|
||||
message_content = { { type = "text", text = message_content } }
|
||||
end
|
||||
for _, image_path in ipairs(opts.image_paths) do
|
||||
table.insert(message_content, {
|
||||
type = "image_url",
|
||||
image_url = {
|
||||
url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
|
||||
},
|
||||
})
|
||||
end
|
||||
messages[#messages].content = message_content
|
||||
end
|
||||
|
||||
local final_messages = {}
|
||||
local prev_role = nil
|
||||
local prev_type = nil
|
||||
|
||||
vim.iter(messages):each(function(message)
|
||||
local role = message.role
|
||||
if
|
||||
role == prev_role
|
||||
and role ~= "tool"
|
||||
and prev_type ~= "function_call"
|
||||
and prev_type ~= "function_call_output"
|
||||
then
|
||||
if role == self.role_map["assistant"] then
|
||||
table.insert(final_messages, { role = self.role_map["user"], content = "Ok" })
|
||||
else
|
||||
table.insert(final_messages, { role = self.role_map["assistant"], content = "Ok, I understand." })
|
||||
end
|
||||
else
|
||||
if role == "user" and prev_role == "tool" and M.is_mistral(provider_conf.endpoint) then
|
||||
table.insert(final_messages, { role = self.role_map["assistant"], content = "Ok, I understand." })
|
||||
end
|
||||
end
|
||||
prev_role = role
|
||||
prev_type = message.type
|
||||
table.insert(final_messages, message)
|
||||
end)
|
||||
|
||||
return final_messages
|
||||
end
|
||||
|
||||
--- Flush any in-flight streamed state at end of a response.
--- Finalizes the accumulated text message (if any) and every tool-use entry
--- still in the "generating" state, marking them "generated".
---@param ctx table Streaming context accumulated across parse_response calls
---@param opts table Callbacks table (on_messages_add / on_stop / ...)
function M:finish_pending_messages(ctx, opts)
  local has_pending_text = ctx.content ~= nil and ctx.content ~= ""
  if has_pending_text then
    -- Empty delta: re-emits the accumulated content with the final state.
    self:add_text_message(ctx, "", "generated", opts)
  end
  local tool_uses = ctx.tool_use_map or {}
  for _, tool_use in pairs(tool_uses) do
    if tool_use.state == "generating" then
      self:add_tool_use_message(ctx, tool_use, "generated", opts)
    end
  end
end
|
||||
|
||||
-- Lazily-populated cache of the known LLM tool names (filled on first
-- add_text_message call); used to drop hallucinated tool invocations.
local llm_tool_names = nil

--- Append a streamed text delta to the context and publish history messages.
--- Besides plain text, this also scans the accumulated content for inline
--- ReAct-style XML tool invocations (<tool_name>/<parameters> blocks) and
--- converts them into tool_use history messages plus ctx.tool_use_map entries.
---@param ctx table Streaming context (content, content_uuid, tool_use_map, turn_id)
---@param text string New text delta (may be empty to finalize)
---@param state string Message state, e.g. "generating" or "generated"
---@param opts table Callbacks (on_messages_add, on_stop)
function M:add_text_message(ctx, text, state, opts)
  if llm_tool_names == nil then
    llm_tool_names = LlmTools.get_tool_names()
  end
  if ctx.content == nil then
    ctx.content = ""
  end
  ctx.content = ctx.content .. text
  -- Strip wrapper tags some models emit around tool XML before parsing.
  local content =
    ctx.content:gsub("<tool_code>", ""):gsub("</tool_code>", ""):gsub("<tool_call>", ""):gsub("<tool_call>", ""):gsub("</tool_call>", "")
  ctx.content = content
  -- Reusing ctx.content_uuid makes repeated calls update the same history
  -- entry instead of appending new ones.
  local msg = HistoryMessage:new("assistant", ctx.content, {
    state = state,
    uuid = ctx.content_uuid,
    original_content = ctx.content,
  })
  ctx.content_uuid = msg.uuid
  local msgs = { msg }
  -- Normalize <tool_name>X</tool_name> + <parameters>...</parameters> pairs
  -- into <X>...</X> so ReActParser can parse them uniformly.
  local xml_content = ctx.content
  local xml_lines = vim.split(xml_content, "\n")
  local cleaned_xml_lines = {}
  local prev_tool_name = nil
  for _, line in ipairs(xml_lines) do
    if line:match("<tool_name>") then
      local tool_name = line:match("<tool_name>(.*)</tool_name>")
      if tool_name then
        prev_tool_name = tool_name
      end
    elseif line:match("<parameters>") then
      if prev_tool_name then
        table.insert(cleaned_xml_lines, "<" .. prev_tool_name .. ">")
      end
      goto continue
    elseif line:match("</parameters>") then
      if prev_tool_name then
        table.insert(cleaned_xml_lines, "</" .. prev_tool_name .. ">")
      end
      goto continue
    end
    table.insert(cleaned_xml_lines, line)
    ::continue::
  end
  local cleaned_xml_content = table.concat(cleaned_xml_lines, "\n")
  local xml = ReActParser.parse(cleaned_xml_content)
  if xml and #xml > 0 then
    local new_content_list = {}
    -- Tracks whether we are inside a ```xml / ```tool_code / ```tool_use fence
    -- so both the opening and the matching closing ``` are dropped.
    local xml_md_openned = false
    for idx, item in ipairs(xml) do
      if item.type == "text" then
        local cleaned_lines = {}
        local lines = vim.split(item.text, "\n")
        for _, line in ipairs(lines) do
          if line:match("^```xml") or line:match("^```tool_code") or line:match("^```tool_use") then
            xml_md_openned = true
          elseif line:match("^```$") then
            if xml_md_openned then
              xml_md_openned = false
            else
              table.insert(cleaned_lines, line)
            end
          else
            table.insert(cleaned_lines, line)
          end
        end
        table.insert(new_content_list, table.concat(cleaned_lines, "\n"))
        goto continue
      end
      -- Ignore tool invocations whose name is not a registered tool.
      if not vim.tbl_contains(llm_tool_names, item.tool_name) then
        goto continue
      end
      -- Parameter values may themselves be JSON strings; decode when possible,
      -- otherwise keep the raw string.
      local input = {}
      for k, v in pairs(item.tool_input or {}) do
        local ok, jsn = pcall(vim.json.decode, v)
        if ok and jsn then
          input[k] = jsn
        else
          input[k] = v
        end
      end
      if next(input) ~= nil then
        -- Deterministic id derived from the text message uuid + item index so
        -- re-parsing the growing stream updates the same tool_use entry.
        local msg_uuid = ctx.content_uuid .. "-" .. idx
        local tool_use_id = msg_uuid
        local tool_message_state = item.partial and "generating" or "generated"
        local msg_ = HistoryMessage:new("assistant", {
          type = "tool_use",
          name = item.tool_name,
          id = tool_use_id,
          input = input,
        }, {
          state = tool_message_state,
          uuid = msg_uuid,
          turn_id = ctx.turn_id,
        })
        msgs[#msgs + 1] = msg_
        ctx.tool_use_map = ctx.tool_use_map or {}
        local input_json = type(input) == "string" and input or vim.json.encode(input)
        -- Update in place if this tool_use id was already recorded.
        local exists = false
        for _, tool_use in pairs(ctx.tool_use_map) do
          if tool_use.id == tool_use_id then
            tool_use.input_json = input_json
            exists = true
          end
        end
        if not exists then
          local tool_key = tostring(vim.tbl_count(ctx.tool_use_map))
          ctx.tool_use_map[tool_key] = {
            uuid = tool_use_id,
            id = tool_use_id,
            name = item.tool_name,
            input_json = input_json,
            state = "generating",
          }
        end
        opts.on_stop({ reason = "tool_use", streaming_tool_use = item.partial })
      end
      ::continue::
    end
    -- Replace the text message content with the tool-XML-free version,
    -- collapsing trailing newlines to a single one.
    msg.message.content = table.concat(new_content_list, "\n"):gsub("\n+$", "\n")
  end
  if opts.on_messages_add then
    opts.on_messages_add(msgs)
  end
end
|
||||
|
||||
--- Accumulate a streamed reasoning/thinking delta and publish it to history.
--- Reuses ctx.reasonging_content_uuid so successive deltas update one entry.
-- NOTE(review): the ctx field name "reasonging_content" is a typo of
-- "reasoning_content" — kept as-is since the field is shared streaming state;
-- confirm no other module reads it before renaming.
---@param ctx table Streaming context
---@param text string Reasoning text delta
---@param state string Message state ("generating" / "generated")
---@param opts table Callbacks (on_messages_add)
function M:add_thinking_message(ctx, text, state, opts)
  ctx.reasonging_content = (ctx.reasonging_content or "") .. text
  local message = HistoryMessage:new("assistant", {
    type = "thinking",
    thinking = ctx.reasonging_content,
    signature = "",
  }, {
    state = state,
    uuid = ctx.reasonging_content_uuid,
    turn_id = ctx.turn_id,
  })
  ctx.reasonging_content_uuid = message.uuid
  if opts.on_messages_add then
    opts.on_messages_add({ message })
  end
end
|
||||
|
||||
--- Publish a tool-use entry to history and update its bookkeeping state.
--- Parses the (possibly partial) accumulated JSON arguments; when still
--- "generating", signals on_stop so the caller knows a tool call is streaming.
---@param ctx table Streaming context (turn_id)
---@param tool_use table Entry from ctx.tool_use_map (name, id, input_json, uuid, state)
---@param state string "generating" or "generated"
---@param opts table Callbacks (on_messages_add, on_stop)
function M:add_tool_use_message(ctx, tool_use, state, opts)
  -- JsonParser tolerates partial JSON; fall back to an empty input table.
  local parsed_input = JsonParser.parse(tool_use.input_json)
  local message = HistoryMessage:new("assistant", {
    type = "tool_use",
    name = tool_use.name,
    id = tool_use.id,
    input = parsed_input or {},
  }, {
    state = state,
    uuid = tool_use.uuid,
    turn_id = ctx.turn_id,
  })
  tool_use.uuid = message.uuid
  tool_use.state = state
  if opts.on_messages_add then
    opts.on_messages_add({ message })
  end
  if state == "generating" then
    opts.on_stop({ reason = "tool_use", streaming_tool_use = true })
  end
end
|
||||
|
||||
--- Record a Response API "reasoning" output item in the chat history.
--- Each item gets a fresh uuid; the (possibly encrypted) reasoning payload
--- is stored verbatim so it can be replayed in later requests.
---@param ctx table Streaming context (turn_id)
---@param reasoning_item table Item with id, encrypted_content, summary
---@param opts table Callbacks (on_messages_add)
function M:add_reasoning_message(ctx, reasoning_item, opts)
  local content = {
    type = "reasoning",
    id = reasoning_item.id,
    encrypted_content = reasoning_item.encrypted_content,
    summary = reasoning_item.summary,
  }
  local meta = {
    state = "generated",
    uuid = Utils.uuid(),
    turn_id = ctx.turn_id,
  }
  local message = HistoryMessage:new("assistant", content, meta)
  if opts.on_messages_add then
    opts.on_messages_add({ message })
  end
end
|
||||
|
||||
--- Convert an OpenAI usage payload to the plugin's token-usage shape.
---@param usage avante.OpenAITokenUsage | nil
---@return avante.LLMTokenUsage | nil
function M.transform_openai_usage(usage)
  -- Treat both Lua nil and JSON null (vim.NIL) as "no usage reported".
  if not usage or usage == vim.NIL then
    return nil
  end
  ---@type avante.LLMTokenUsage
  return {
    prompt_tokens = usage.prompt_tokens,
    completion_tokens = usage.completion_tokens,
  }
end
|
||||
|
||||
--- Parse one SSE data chunk from the provider and drive the streaming callbacks.
--- Handles two wire formats: the event-typed Response API (chunks carrying a
--- string `type` field) and the classic Chat Completions delta format.
--- Mutates ctx (content, tool_use_map, think-tag flags, last_response_id) across calls.
---@param ctx table Streaming context shared across chunks of one response
---@param data_stream string Raw SSE `data:` payload (JSON, or the "[DONE]" sentinel)
---@param opts table Callbacks (on_chunk, on_stop, on_messages_add, update_tokens_usage)
function M:parse_response(ctx, data_stream, _, opts)
  -- Stream terminator: finalize pending messages and report the stop reason.
  if data_stream:match('"%[DONE%]":') or data_stream == "[DONE]" then
    self:finish_pending_messages(ctx, opts)
    if ctx.tool_use_map and vim.tbl_count(ctx.tool_use_map) > 0 then
      ctx.tool_use_map = {}
      opts.on_stop({ reason = "tool_use" })
    else
      opts.on_stop({ reason = "complete" })
    end
    return
  end

  local jsn = vim.json.decode(data_stream)

  -- Check if this is a Response API event (has 'type' field)
  if jsn.type and type(jsn.type) == "string" then
    -- Response API event-driven format
    if jsn.type == "response.output_text.delta" then
      -- Text content delta
      if jsn.delta and jsn.delta ~= vim.NIL and jsn.delta ~= "" then
        if opts.on_chunk then
          opts.on_chunk(jsn.delta)
        end
        self:add_text_message(ctx, jsn.delta, "generating", opts)
      end
    elseif jsn.type == "response.reasoning_summary_text.delta" then
      -- Reasoning summary delta; wrapped in <think> tags for display.
      if jsn.delta and jsn.delta ~= vim.NIL and jsn.delta ~= "" then
        -- Emit the opening <think> tag exactly once per response.
        if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
          ctx.returned_think_start_tag = true
          if opts.on_chunk then
            opts.on_chunk("<think>\n")
          end
        end
        ctx.last_think_content = jsn.delta
        self:add_thinking_message(ctx, jsn.delta, "generating", opts)
        if opts.on_chunk then
          opts.on_chunk(jsn.delta)
        end
      end
    elseif jsn.type == "response.function_call_arguments.delta" then
      -- Function call arguments delta: accumulate raw JSON keyed by output_index.
      if jsn.delta and jsn.delta ~= vim.NIL and jsn.delta ~= "" then
        if not ctx.tool_use_map then
          ctx.tool_use_map = {}
        end
        local tool_key = tostring(jsn.output_index or 0)
        if not ctx.tool_use_map[tool_key] then
          ctx.tool_use_map[tool_key] = {
            name = jsn.name or "",
            id = jsn.call_id or "",
            input_json = jsn.delta,
          }
        else
          ctx.tool_use_map[tool_key].input_json = ctx.tool_use_map[tool_key].input_json .. jsn.delta
        end
      end
    elseif jsn.type == "response.output_item.added" then
      -- Output item added (could be function call or reasoning)
      if jsn.item and jsn.item.type == "function_call" then
        local tool_key = tostring(jsn.output_index or 0)
        if not ctx.tool_use_map then
          ctx.tool_use_map = {}
        end
        ctx.tool_use_map[tool_key] = {
          name = jsn.item.name or "",
          id = jsn.item.call_id or jsn.item.id or "",
          input_json = "",
        }
        self:add_tool_use_message(ctx, ctx.tool_use_map[tool_key], "generating", opts)
      elseif jsn.item and jsn.item.type == "reasoning" then
        -- Add reasoning item to history
        self:add_reasoning_message(ctx, jsn.item, opts)
      end
    elseif jsn.type == "response.output_item.done" then
      -- Output item done: the final `arguments` replaces accumulated deltas.
      if jsn.item and jsn.item.type == "function_call" then
        local tool_key = tostring(jsn.output_index or 0)
        if ctx.tool_use_map and ctx.tool_use_map[tool_key] then
          local tool_use = ctx.tool_use_map[tool_key]
          if jsn.item.arguments then
            tool_use.input_json = jsn.item.arguments
          end
          self:add_tool_use_message(ctx, tool_use, "generated", opts)
        end
      end
    elseif jsn.type == "response.completed" or jsn.type == "response.done" then
      -- Response completed - save response.id for future requests
      if jsn.response and jsn.response.id then
        ctx.last_response_id = jsn.response.id
        -- Store in provider for next request (consumed by parse_curl_args).
        self.last_response_id = jsn.response.id
      end
      -- Close the <think> block if it was opened and not yet closed.
      if
        ctx.returned_think_start_tag ~= nil
        and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
      then
        ctx.returned_think_end_tag = true
        if opts.on_chunk then
          if
            ctx.last_think_content
            and ctx.last_think_content ~= vim.NIL
            and ctx.last_think_content:sub(-1) ~= "\n"
          then
            opts.on_chunk("\n</think>\n")
          else
            opts.on_chunk("</think>\n")
          end
        end
        self:add_thinking_message(ctx, "", "generated", opts)
      end
      self:finish_pending_messages(ctx, opts)
      local usage = nil
      if jsn.response and jsn.response.usage then
        usage = self.transform_openai_usage(jsn.response.usage)
      end
      if ctx.tool_use_map and vim.tbl_count(ctx.tool_use_map) > 0 then
        opts.on_stop({ reason = "tool_use", usage = usage })
      else
        opts.on_stop({ reason = "complete", usage = usage })
      end
    elseif jsn.type == "error" then
      -- Error event
      local error_msg = jsn.error and vim.inspect(jsn.error) or "Unknown error"
      opts.on_stop({ reason = "error", error = error_msg })
    end
    return
  end

  -- Chat Completions API format (original code)
  if jsn.usage and jsn.usage ~= vim.NIL then
    if opts.update_tokens_usage then
      local usage = self.transform_openai_usage(jsn.usage)
      if usage then
        opts.update_tokens_usage(usage)
      end
    end
  end
  if jsn.error and jsn.error ~= vim.NIL then
    opts.on_stop({ reason = "error", error = vim.inspect(jsn.error) })
    return
  end
  ---@cast jsn AvanteOpenAIChatResponse
  if not jsn.choices then
    return
  end
  local choice = jsn.choices[1]
  if not choice then
    return
  end
  local delta = choice.delta
  if not delta then
    -- o1-family models return a full message instead of a delta.
    local provider_conf = Providers.parse_config(self)
    if provider_conf.model:match("o1") then
      delta = choice.message
    end
  end
  if not delta then
    return
  end
  if delta.reasoning_content and delta.reasoning_content ~= vim.NIL and delta.reasoning_content ~= "" then
    -- DeepSeek-style reasoning field; wrap in <think> tags (opened once).
    if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
      ctx.returned_think_start_tag = true
      if opts.on_chunk then
        opts.on_chunk("<think>\n")
      end
    end
    ctx.last_think_content = delta.reasoning_content
    self:add_thinking_message(ctx, delta.reasoning_content, "generating", opts)
    if opts.on_chunk then
      opts.on_chunk(delta.reasoning_content)
    end
  elseif delta.reasoning and delta.reasoning ~= vim.NIL then
    -- Alternative `reasoning` field used by some providers (e.g. OpenRouter).
    if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
      ctx.returned_think_start_tag = true
      if opts.on_chunk then
        opts.on_chunk("<think>\n")
      end
    end
    ctx.last_think_content = delta.reasoning
    self:add_thinking_message(ctx, delta.reasoning, "generating", opts)
    if opts.on_chunk then
      opts.on_chunk(delta.reasoning)
    end
  elseif delta.tool_calls and delta.tool_calls ~= vim.NIL then
    local choice_index = choice.index or 0
    for idx, tool_call in ipairs(delta.tool_calls) do
      --- In Gemini's so-called OpenAI Compatible API, tool_call.index is nil, which is quite absurd! Therefore, a compatibility fix is needed here.
      if tool_call.index == nil then
        tool_call.index = choice_index + idx - 1
      end
      if not ctx.tool_use_map then
        ctx.tool_use_map = {}
      end
      local tool_key = tostring(tool_call.index)
      local prev_tool_key = tostring(tool_call.index - 1)
      if not ctx.tool_use_map[tool_key] then
        -- A new index implies the previous tool call finished streaming.
        local prev_tool_use = ctx.tool_use_map[prev_tool_key]
        if tool_call.index > 0 and prev_tool_use then
          self:add_tool_use_message(ctx, prev_tool_use, "generated", opts)
        end
        local tool_use = {
          name = tool_call["function"].name,
          id = tool_call.id,
          input_json = type(tool_call["function"].arguments) == "string" and tool_call["function"].arguments
            or "",
        }
        ctx.tool_use_map[tool_key] = tool_use
        self:add_tool_use_message(ctx, tool_use, "generating", opts)
      else
        local tool_use = ctx.tool_use_map[tool_key]
        if tool_call["function"].arguments == vim.NIL then
          tool_call["function"].arguments = ""
        end
        tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
        -- self:add_tool_use_message(ctx, tool_use, "generating", opts)
      end
    end
  elseif delta.content then
    -- First plain-content chunk after reasoning: close the <think> block.
    if
      ctx.returned_think_start_tag ~= nil
      and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
    then
      ctx.returned_think_end_tag = true
      if opts.on_chunk then
        if
          ctx.last_think_content
          and ctx.last_think_content ~= vim.NIL
          and ctx.last_think_content:sub(-1) ~= "\n"
        then
          opts.on_chunk("\n</think>\n")
        else
          opts.on_chunk("</think>\n")
        end
      end
      self:add_thinking_message(ctx, "", "generated", opts)
    end
    if delta.content ~= vim.NIL then
      if opts.on_chunk then
        opts.on_chunk(delta.content)
      end
      self:add_text_message(ctx, delta.content, "generating", opts)
    end
  end
  if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" or choice.finish_reason == "length" then
    self:finish_pending_messages(ctx, opts)
    if ctx.tool_use_map and vim.tbl_count(ctx.tool_use_map) > 0 then
      opts.on_stop({ reason = "tool_use", usage = self.transform_openai_usage(jsn.usage) })
    else
      opts.on_stop({ reason = "complete", usage = self.transform_openai_usage(jsn.usage) })
    end
  end
  if choice.finish_reason == "tool_calls" then
    self:finish_pending_messages(ctx, opts)
    opts.on_stop({
      reason = "tool_use",
      usage = self.transform_openai_usage(jsn.usage),
    })
  end
end
|
||||
|
||||
--- Handle a complete (non-streaming) Chat Completions response body.
--- Emits the first choice's message content through on_chunk / history,
--- then schedules the completion callback.
---@param data string Raw JSON response body
---@param opts table Callbacks (on_chunk, on_stop, on_messages_add)
function M:parse_response_without_stream(data, _, opts)
  ---@type AvanteOpenAIChatResponse
  local response = vim.json.decode(data)
  local first_choice = response.choices and response.choices[1]
  if not first_choice then
    return
  end
  local message = first_choice.message
  if not (message and message.content) then
    return
  end
  if opts.on_chunk then
    opts.on_chunk(message.content)
  end
  self:add_text_message({}, message.content, "generated", opts)
  vim.schedule(function()
    opts.on_stop({ reason = "complete" })
  end)
end
|
||||
|
||||
--- Build the curl request (url, headers, body) for one LLM call.
--- Chooses between the Chat Completions endpoint (/chat/completions) and the
--- Response API endpoint (/responses) based on provider configuration, and
--- adapts tools, messages and token-limit fields to the selected format.
---@param prompt_opts AvantePromptOptions
---@return AvanteCurlOutput|nil Request spec, or nil when the API key is missing
function M:parse_curl_args(prompt_opts)
  local provider_conf, request_body = Providers.parse_config(self)
  local disable_tools = provider_conf.disable_tools or false

  local headers = {
    ["Content-Type"] = "application/json",
  }

  if Providers.env.require_api_key(provider_conf) then
    local api_key = self.parse_api_key()
    if api_key == nil then
      Utils.error(
        Config.provider .. ": API key is not set, please set it in your environment variable or config file"
      )
      return nil
    end
    headers["Authorization"] = "Bearer " .. api_key
  end

  -- OpenRouter wants attribution headers and an opt-in flag for reasoning.
  if M.is_openrouter(provider_conf.endpoint) then
    headers["HTTP-Referer"] = "https://github.com/yetone/avante.nvim"
    headers["X-Title"] = "Avante.nvim"
    request_body.include_reasoning = true
  end

  self.set_allowed_params(provider_conf, request_body)
  local use_response_api = Providers.resolve_use_response_api(provider_conf, prompt_opts)

  local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true

  -- With ReAct prompting, tools are described in the prompt, not the API.
  local tools = nil
  if not disable_tools and prompt_opts.tools and not use_ReAct_prompt then
    tools = {}
    for _, tool in ipairs(prompt_opts.tools) do
      local transformed_tool = self:transform_tool(tool)
      -- Response API uses flattened tool structure
      if use_response_api then
        -- Convert from {type: "function", function: {name, description, parameters}}
        -- to {type: "function", name, description, parameters}
        if transformed_tool.type == "function" and transformed_tool["function"] then
          transformed_tool = {
            type = "function",
            name = transformed_tool["function"].name,
            description = transformed_tool["function"].description,
            parameters = transformed_tool["function"].parameters,
          }
        end
      end
      table.insert(tools, transformed_tool)
    end
  end

  Utils.debug("endpoint", provider_conf.endpoint)
  Utils.debug("model", provider_conf.model)

  -- Stop generation at the closing tool tag when ReAct prompting is active.
  local stop = nil
  if use_ReAct_prompt then
    stop = { "</tool_use>" }
  end

  -- Determine endpoint path based on use_response_api
  local endpoint_path = use_response_api and "/responses" or "/chat/completions"

  local parsed_messages = self:parse_messages(prompt_opts)

  -- Build base body
  local base_body = {
    model = provider_conf.model,
    stop = stop,
    stream = true,
    tools = tools,
  }

  -- Response API uses 'input' instead of 'messages'
  if use_response_api then
    -- Check if we have tool results - if so, use previous_response_id
    local has_function_outputs = false
    for _, msg in ipairs(parsed_messages) do
      if msg.type == "function_call_output" then
        has_function_outputs = true
        break
      end
    end

    if has_function_outputs and self.last_response_id then
      -- When sending function outputs, use previous_response_id
      -- (self.last_response_id was stored by parse_response on response.completed).
      base_body.previous_response_id = self.last_response_id
      -- Only send the function outputs, not the full history
      local function_outputs = {}
      for _, msg in ipairs(parsed_messages) do
        if msg.type == "function_call_output" then
          table.insert(function_outputs, msg)
        end
      end
      base_body.input = function_outputs
      -- Clear the stored response_id after using it
      self.last_response_id = nil
    else
      -- Normal request without tool results
      base_body.input = parsed_messages
    end

    -- Response API uses max_output_tokens instead of max_tokens/max_completion_tokens
    if request_body.max_completion_tokens then
      request_body.max_output_tokens = request_body.max_completion_tokens
      request_body.max_completion_tokens = nil
    end
    if request_body.max_tokens then
      request_body.max_output_tokens = request_body.max_tokens
      request_body.max_tokens = nil
    end
    -- Response API doesn't use stream_options
    base_body.stream_options = nil
  else
    base_body.messages = parsed_messages
    -- Mistral rejects stream_options; everyone else gets per-stream usage.
    base_body.stream_options = not M.is_mistral(provider_conf.endpoint) and {
      include_usage = true,
    } or nil
  end

  return {
    url = Utils.url_join(provider_conf.endpoint, endpoint_path),
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
    headers = Utils.tbl_override(headers, self.extra_headers),
    -- request_body (user/provider overrides) wins over base_body defaults.
    body = vim.tbl_deep_extend("force", base_body, request_body),
  }
end
|
||||
|
||||
return M
|
||||
Reference in New Issue
Block a user