--- Generate a completion with tool-use support for agentic mode.
---
--- Sends the conversation plus tool definitions to the Claude Messages API
--- via curl and hands the *raw* decoded response to the caller (a separate
--- parser is expected to interpret tool_use blocks).
---
---@param messages table[] Conversation history (internal message format)
---@param context table Context information passed to llm.build_system_prompt
---@param tool_definitions table Tool definitions
--- NOTE(review): `tool_definitions` is currently unused — the request always
--- uses tools_module.to_claude_format(); confirm whether the parameter should
--- be honored.
---@param callback fun(response: table|nil, error: string|nil) Called exactly once with the raw response or an error
function M.generate_with_tools(messages, context, tool_definitions, callback)
  local api_key = get_api_key()
  if not api_key then
    callback(nil, "Claude API key not configured")
    return
  end

  local tools_module = require("codetyper.agent.tools")
  local agent_prompts = require("codetyper.prompts.agent")

  -- Base system prompt, extended with the agent persona and tool-calling rules.
  local system_prompt = llm.build_system_prompt(context)
  system_prompt = system_prompt .. "\n\n" .. agent_prompts.system
  system_prompt = system_prompt .. "\n\n" .. agent_prompts.tool_instructions

  -- Request body including the tool schemas in Claude's format.
  local body = {
    model = get_model(),
    max_tokens = 4096,
    system = system_prompt,
    messages = M.format_messages_for_claude(messages),
    tools = tools_module.to_claude_format(),
  }

  local json_body = vim.json.encode(body)

  local cmd = {
    "curl",
    "-s",
    "-X", "POST",
    API_URL,
    "-H", "Content-Type: application/json",
    "-H", "x-api-key: " .. api_key,
    "-H", "anthropic-version: 2023-06-01",
    "-d", json_body,
  }

  -- Guard so the callback fires at most once: previously a non-zero curl exit
  -- after a successful on_stdout could invoke the callback a second time.
  local finished = false
  local function finish(response, err)
    if finished then
      return
    end
    finished = true
    vim.schedule(function()
      callback(response, err)
    end)
  end

  vim.fn.jobstart(cmd, {
    stdout_buffered = true,
    on_stdout = function(_, data)
      -- jobstart delivers a single {""} chunk when there is no output.
      if not data or #data == 0 or (data[1] == "" and #data == 1) then
        return
      end

      local response_text = table.concat(data, "\n")
      local ok, response = pcall(vim.json.decode, response_text)

      if not ok then
        finish(nil, "Failed to parse Claude response")
        return
      end

      if response.error then
        finish(nil, response.error.message or "Claude API error")
        return
      end

      -- Return the raw response; the agent parser handles tool_use blocks.
      finish(response, nil)
    end,
    on_stderr = function(_, data)
      if data and #data > 0 and data[1] ~= "" then
        finish(nil, "Claude API request failed: " .. table.concat(data, "\n"))
      end
    end,
    on_exit = function(_, code)
      if code ~= 0 then
        finish(nil, "Claude API request failed with code: " .. code)
      end
    end,
  })
end

--- Convert internal messages into the Claude Messages API shape.
---
--- User messages pass through unchanged (a table content value is a list of
--- tool_result blocks, a string is plain text — Claude accepts both forms).
--- Assistant messages become a content array of text and tool_use blocks;
--- assistant entries with no content at all are dropped.
---
---@param messages table[] Internal message format
---@return table[] Claude API message format
function M.format_messages_for_claude(messages)
  local formatted = {}

  for _, msg in ipairs(messages) do
    if msg.role == "user" then
      -- Both string content and tool-result tables are forwarded verbatim
      -- (the original had two identical branches here; collapsed).
      table.insert(formatted, {
        role = "user",
        content = msg.content,
      })
    elseif msg.role == "assistant" then
      local content = {}

      if msg.content and msg.content ~= "" then
        table.insert(content, {
          type = "text",
          text = msg.content,
        })
      end

      if msg.tool_calls then
        for _, tool_call in ipairs(msg.tool_calls) do
          table.insert(content, {
            type = "tool_use",
            id = tool_call.id,
            name = tool_call.name,
            input = tool_call.parameters,
          })
        end
      end

      if #content > 0 then
        table.insert(formatted, {
          role = "assistant",
          content = content,
        })
      end
    end
  end

  return formatted
end

return M
---@class CopilotToken
---@field annotations_enabled boolean
---@field chat_enabled boolean
---@field chat_jetbrains_enabled boolean
---@field code_quote_enabled boolean
---@field codesearch boolean
---@field copilotignore_enabled boolean
---@field endpoints {api: string, ["origin-tracker"]: string, proxy: string, telemetry: string}
---@field expires_at integer
---@field individual boolean
---@field nes_enabled boolean
---@field prompt_8k boolean
---@field public_suggestions string
---@field refresh_in integer
---@field sku string
---@field snippy_load_test_enabled boolean
---@field telemetry string
---@field token string
---@field tracking_id string
---@field vsc_electron_fetcher boolean
---@field xcode boolean
---@field xcode_chat boolean

local curl = require("plenary.curl")

local Path = require("plenary.path")
local Utils = require("avante.utils")
local Providers = require("avante.providers")
local OpenAI = require("avante.providers").openai

local H = {}

---@class AvanteProviderFunctor
local M = {}

local copilot_path = vim.fn.stdpath("data") .. "/avante/github-copilot.json"
local lockfile_path = vim.fn.stdpath("data") .. "/avante/copilot-timer.lock"

--- Probe whether a process is alive by sending signal 0.
---@param pid integer
---@return boolean
local function is_process_running(pid)
  -- uv.kill returns 0 on success and nil (plus an error) on failure,
  -- so a single comparison covers both outcomes.
  return vim.uv.kill(pid, 0) == 0
end

--- Try to become the single Neovim instance responsible for refreshing the
--- Copilot token. Writes a pid-stamped temp file, then renames it over the
--- lockfile unless a live owner already holds it.
---@return boolean acquired True when this process now owns the lock
local function try_acquire_timer_lock()
  local lockfile = Path:new(lockfile_path)

  local tmp_lockfile = lockfile_path .. ".tmp." .. vim.fn.getpid()

  Path:new(tmp_lockfile):write(tostring(vim.fn.getpid()), "w")

  -- An existing lock held by a live process wins.
  if lockfile:exists() then
    local pid = tonumber(lockfile:read())
    if pid and is_process_running(pid) then
      os.remove(tmp_lockfile)
      return false
    end
  end

  -- NOTE(review): os.rename overwrites an existing file on POSIX, so two
  -- instances racing past the liveness check could both "acquire" — confirm
  -- whether this best-effort behavior is acceptable.
  local success = os.rename(tmp_lockfile, lockfile_path)
  if not success then
    os.remove(tmp_lockfile)
    return false
  end

  return true
end

--- Periodically (every 30s) check whether the current lock owner died and,
--- if so, take over token-refresh duty.
local function start_manager_check_timer()
  if M._manager_check_timer then
    M._manager_check_timer:stop()
    M._manager_check_timer:close()
  end

  M._manager_check_timer = vim.uv.new_timer()
  M._manager_check_timer:start(
    30000,
    30000,
    vim.schedule_wrap(function()
      if not M._refresh_timer and try_acquire_timer_lock() then
        M.setup_timer()
      end
    end)
  )
end

---@class OAuthToken
---@field user string
---@field oauth_token string
---
--- Read the GitHub OAuth token written by copilot.lua (hosts.json) or
--- copilot.vim (apps.json). Raises if neither file exists.
---@return string
function H.get_oauth_token()
  local xdg_config = vim.fn.expand("$XDG_CONFIG_HOME")
  local os_name = Utils.get_os_name()
  ---@type string
  local config_dir

  if xdg_config and vim.fn.isdirectory(xdg_config) > 0 then
    config_dir = xdg_config
  elseif vim.tbl_contains({ "linux", "darwin" }, os_name) then
    config_dir = vim.fn.expand("~/.config")
  else
    config_dir = vim.fn.expand("~/AppData/Local")
  end

  --- hosts.json (copilot.lua), apps.json (copilot.vim)
  ---@type Path[]
  local paths = vim.iter({ "hosts.json", "apps.json" }):fold({}, function(acc, path)
    local yason = Path:new(config_dir):joinpath("github-copilot", path)
    if yason:exists() then
      table.insert(acc, yason)
    end
    return acc
  end)
  if #paths == 0 then
    error("You must setup copilot with either copilot.lua or copilot.vim", 2)
  end

  local yason = paths[1]
  return vim
    .iter(
      ---@type table
      ---@diagnostic disable-next-line: param-type-mismatch
      vim.json.decode(yason:read())
    )
    :filter(function(k, _)
      return k:match("github.com")
    end)
    ---@param acc {oauth_token: string}
    :fold({}, function(acc, _, v)
      acc.oauth_token = v.oauth_token
      return acc
    end)
    .oauth_token
end

H.chat_auth_url = "https://api.github.com/copilot_internal/v2/token"
function H.chat_completion_url(base_url)
  return Utils.url_join(base_url, "/chat/completions")
end
function H.response_url(base_url)
  return Utils.url_join(base_url, "/responses")
end

--- Exchange the OAuth token for a short-lived Copilot API token and persist
--- it to disk. No-op when the cached token is still valid (unless forced).
---@param async boolean|nil Defaults to true; when false the HTTP call blocks
---@param force boolean|nil Refresh even if the cached token has not expired
---@return boolean|nil false when skipped because the token is still valid
function H.refresh_token(async, force)
  if not M.state then
    error("internal initialization error")
  end

  async = async == nil and true or async
  force = force or false

  -- Do not refresh token if not forced or not expired
  if
    not force
    and M.state.github_token
    and M.state.github_token.expires_at
    and M.state.github_token.expires_at > math.floor(os.time())
  then
    return false
  end

  local provider_conf = Providers.get_config("copilot")

  local curl_opts = {
    headers = {
      ["Authorization"] = "token " .. M.state.oauth_token,
      ["Accept"] = "application/json",
    },
    timeout = provider_conf.timeout,
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
  }

  local function handle_response(response)
    if response.status == 200 then
      M.state.github_token = vim.json.decode(response.body)
      local file = Path:new(copilot_path)
      file:write(vim.json.encode(M.state.github_token), "w")
      if not vim.g.avante_login then
        vim.g.avante_login = true
      end

      -- A synchronous refresh resets the background timer so the next
      -- automatic refresh is scheduled relative to the new expiry.
      if not async and M._refresh_timer then
        M.setup_timer()
      end

      return true
    else
      -- error() never returns; the original's trailing `return false` here
      -- was unreachable and has been removed.
      error("Failed to get success response: " .. vim.inspect(response))
    end
  end

  if async then
    curl.get(
      H.chat_auth_url,
      vim.tbl_deep_extend("force", {
        callback = handle_response,
      }, curl_opts)
    )
  else
    local response = curl.get(H.chat_auth_url, curl_opts)
    handle_response(response)
  end
end

---@private
---@class AvanteCopilotState
---@field oauth_token string
---@field github_token CopilotToken?
M.state = nil

M.api_key_name = ""
M.tokenizer_id = "gpt-4o"
M.role_map = {
  user = "user",
  assistant = "assistant",
}

function M:is_disable_stream()
  return false
end

setmetatable(M, { __index = OpenAI })

--- List chat-capable Copilot models (cached after the first call).
---@return table[] models
function M:list_models()
  if M._model_list_cache then
    return M._model_list_cache
  end
  if not M._is_setup then
    M.setup()
  end
  -- refresh token synchronously, only if it has expired
  -- (this should rarely happen, as we refresh the token in the background)
  H.refresh_token(false, false)
  local provider_conf = Providers.parse_config(self)
  local headers = self:build_headers()
  local curl_opts = {
    headers = Utils.tbl_override(headers, self.extra_headers),
    timeout = provider_conf.timeout,
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
  }

  local function handle_response(response)
    if response.status == 200 then
      local body = vim.json.decode(response.body)
      -- ref: https://github.com/CopilotC-Nvim/CopilotChat.nvim/blob/16d897fd43d07e3b54478ccdb2f8a16e4df4f45a/lua/CopilotChat/config/providers.lua#L171-L187
      local models = vim.iter(body.data)
        :filter(function(model)
          return model.capabilities.type == "chat" and not vim.endswith(model.id, "paygo")
        end)
        :map(function(model)
          return {
            id = model.id,
            display_name = model.name,
            name = "copilot/" .. model.name .. " (" .. model.id .. ")",
            provider_name = "copilot",
            tokenizer = model.capabilities.tokenizer,
            max_input_tokens = model.capabilities.limits.max_prompt_tokens,
            max_output_tokens = model.capabilities.limits.max_output_tokens,
            policy = not model["policy"] or model["policy"]["state"] == "enabled",
            version = model.version,
          }
        end)
        :totable()
      M._model_list_cache = models
      return models
    else
      -- error() never returns; the unreachable `return {}` was removed.
      error("Failed to get success response: " .. vim.inspect(response))
    end
  end

  local response = curl.get((M.state.github_token.endpoints.api or "") .. "/models", curl_opts)
  return handle_response(response)
end

--- Headers required by the Copilot chat endpoints.
---@return table<string, string>
function M:build_headers()
  return {
    ["Authorization"] = "Bearer " .. M.state.github_token.token,
    ["User-Agent"] = "GitHubCopilotChat/0.26.7",
    ["Editor-Version"] = "vscode/1.105.1",
    ["Editor-Plugin-Version"] = "copilot-chat/0.26.7",
    ["Copilot-Integration-Id"] = "vscode-chat",
    ["Openai-Intent"] = "conversation-edits",
  }
end

--- Build the curl request description for a chat/response-API call.
---@param prompt_opts table Prompt options (messages, tools, ...)
---@return table curl args (url, headers, body, ...)
function M:parse_curl_args(prompt_opts)
  -- refresh token synchronously, only if it has expired
  -- (this should rarely happen, as we refresh the token in the background)
  H.refresh_token(false, false)

  local provider_conf, request_body = Providers.parse_config(self)
  local use_response_api = Providers.resolve_use_response_api(provider_conf, prompt_opts)
  local disable_tools = provider_conf.disable_tools or false

  -- Apply OpenAI's set_allowed_params for Response API compatibility
  OpenAI.set_allowed_params(provider_conf, request_body)

  local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true

  local tools = nil
  if not disable_tools and prompt_opts.tools and not use_ReAct_prompt then
    tools = {}
    for _, tool in ipairs(prompt_opts.tools) do
      local transformed_tool = OpenAI:transform_tool(tool)
      -- Response API uses flattened tool structure
      if use_response_api then
        if transformed_tool.type == "function" and transformed_tool["function"] then
          transformed_tool = {
            type = "function",
            name = transformed_tool["function"].name,
            description = transformed_tool["function"].description,
            parameters = transformed_tool["function"].parameters,
          }
        end
      end
      table.insert(tools, transformed_tool)
    end
  end

  local headers = self:build_headers()

  -- Tell Copilot whether a human or the agent initiated this turn.
  if prompt_opts.messages and #prompt_opts.messages > 0 then
    local last_message = prompt_opts.messages[#prompt_opts.messages]
    local initiator = last_message.role == "user" and "user" or "agent"
    headers["X-Initiator"] = initiator
  end

  local parsed_messages = self:parse_messages(prompt_opts)

  local base_body = {
    model = provider_conf.model,
    stream = true,
    tools = tools,
  }

  -- Response API uses 'input' instead of 'messages'
  -- NOTE: Copilot doesn't support previous_response_id, always send full history
  if use_response_api then
    base_body.input = parsed_messages

    -- Response API uses max_output_tokens instead of max_tokens/max_completion_tokens
    if request_body.max_completion_tokens then
      request_body.max_output_tokens = request_body.max_completion_tokens
      request_body.max_completion_tokens = nil
    end
    if request_body.max_tokens then
      request_body.max_output_tokens = request_body.max_tokens
      request_body.max_tokens = nil
    end
    -- Response API doesn't use stream_options
    base_body.stream_options = nil
    base_body.include = { "reasoning.encrypted_content" }
    base_body.reasoning = {
      summary = "detailed",
    }
    base_body.truncation = "disabled"
  else
    base_body.messages = parsed_messages
    base_body.stream_options = {
      include_usage = true,
    }
  end

  local base_url = M.state.github_token.endpoints.api or provider_conf.endpoint
  local build_url = use_response_api and H.response_url or H.chat_completion_url

  return {
    url = build_url(base_url),
    timeout = provider_conf.timeout,
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
    headers = Utils.tbl_override(headers, self.extra_headers),
    body = vim.tbl_deep_extend("force", base_body, request_body),
  }
end

M._refresh_timer = nil

--- (Re)start the background token-refresh timer: first tick 2 minutes before
--- the current token expires, then every 28 minutes.
function M.setup_timer()
  if M._refresh_timer then
    M._refresh_timer:stop()
    M._refresh_timer:close()
  end

  local now = math.floor(os.time())
  local expires_at = M.state.github_token and M.state.github_token.expires_at or now
  local time_until_expiry = math.max(0, expires_at - now)
  -- Refresh 2 minutes before expiration
  local initial_interval = math.max(0, (time_until_expiry - 120) * 1000)
  -- Regular interval of 28 minutes after the first refresh
  local repeat_interval = 28 * 60 * 1000

  M._refresh_timer = vim.uv.new_timer()
  M._refresh_timer:start(
    initial_interval,
    repeat_interval,
    vim.schedule_wrap(function()
      H.refresh_token(true, true)
    end)
  )
end

--- Watch the persisted token file so non-manager instances pick up tokens
--- refreshed by the manager instance.
function M.setup_file_watcher()
  if M._file_watcher then
    return
  end

  local copilot_token_file = Path:new(copilot_path)
  M._file_watcher = vim.uv.new_fs_event()

  -- NOTE(review): this assumes copilot_path exists when the watcher starts;
  -- confirm behavior when the token file has never been written.
  M._file_watcher:start(
    copilot_path,
    {},
    vim.schedule_wrap(function()
      if copilot_token_file:exists() then
        local ok, token = pcall(vim.json.decode, copilot_token_file:read())
        if ok then
          M.state.github_token = token
        end
      end
    end)
  )
end

M._is_setup = false

--- True when a Copilot OAuth token can be found on disk.
function M.is_env_set()
  local ok = pcall(function()
    H.get_oauth_token()
  end)
  return ok
end

--- Initialize provider state: load any cached token, elect a refresh manager
--- via the lockfile, start the file watcher and manager-check timer.
function M.setup()
  local copilot_token_file = Path:new(copilot_path)

  if not M.state then
    M.state = {
      github_token = nil,
      oauth_token = H.get_oauth_token(),
    }
  end

  -- Load and validate existing token
  if copilot_token_file:exists() then
    local ok, token = pcall(vim.json.decode, copilot_token_file:read())
    if ok and token.expires_at and token.expires_at > math.floor(os.time()) then
      M.state.github_token = token
    end
  end

  -- Either become the refresh manager, or do a one-shot async refresh and
  -- rely on the manager instance (plus the file watcher) afterwards.
  local timer_lock_acquired = try_acquire_timer_lock()
  if timer_lock_acquired then
    M.setup_timer()
  else
    vim.schedule(function()
      H.refresh_token(true, false)
    end)
  end

  M.setup_file_watcher()

  start_manager_check_timer()

  require("avante.tokenizers").setup(M.tokenizer_id)
  vim.g.avante_login = true
  M._is_setup = true
end

--- Release timers, the lockfile (when owned), and the file watcher.
function M.cleanup()
  if M._refresh_timer then
    M._refresh_timer:stop()
    M._refresh_timer:close()
    M._refresh_timer = nil

    -- Remove lockfile if we were the manager
    local lockfile = Path:new(lockfile_path)
    if lockfile:exists() then
      local content = lockfile:read()
      local pid = tonumber(content)
      if pid and pid == vim.fn.getpid() then
        lockfile:rm()
      end
    end
  end

  if M._manager_check_timer then
    M._manager_check_timer:stop()
    M._manager_check_timer:close()
    M._manager_check_timer = nil
  end

  if M._file_watcher then
    ---@diagnostic disable-next-line: param-type-mismatch
    M._file_watcher:stop()
    -- Close the libuv handle as well; the original only stopped it, leaking
    -- the handle until process exit.
    M._file_watcher:close()
    M._file_watcher = nil
  end
end

-- Register cleanup on Neovim exit
vim.api.nvim_create_autocmd("VimLeavePre", {
  callback = function()
    M.cleanup()
  end,
})

return M
local Utils = require("avante.utils")
local Providers = require("avante.providers")
local Clipboard = require("avante.clipboard")
local OpenAI = require("avante.providers").openai
local Prompts = require("avante.utils.prompts")

---@class AvanteProviderFunctor
local M = {}

M.api_key_name = "GEMINI_API_KEY"
M.role_map = {
  user = "user",
  assistant = "model",
}

function M:is_disable_stream()
  return false
end

--- Convert an Avante tool description into a Gemini functionDeclaration.
---@param tool AvanteLLMTool
---@return table declaration {name, description, parameters?}
function M:transform_to_function_declaration(tool)
  local input_schema_properties, required = Utils.llm_tool_param_fields_to_json_schema(tool.param.fields)
  local parameters = nil
  -- Gemini rejects an empty properties object, so omit `parameters` entirely
  -- for zero-argument tools.
  if not vim.tbl_isempty(input_schema_properties) then
    parameters = {
      type = "object",
      properties = input_schema_properties,
      required = required,
    }
  end
  return {
    name = tool.name,
    description = tool.get_description and tool.get_description() or tool.description,
    parameters = parameters,
  }
end

--- Translate Avante's message history into Gemini's contents/systemInstruction
--- shape, handling text, images, tool calls/results, thinking blocks, and the
--- ReAct prompting mode.
---@param opts table Prompt options (messages, system_prompt, image_paths, ...)
---@return table {systemInstruction, contents}
function M:parse_messages(opts)
  local provider_conf, _ = Providers.parse_config(self)
  local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true

  local contents = {}
  local prev_role = nil

  local tool_id_to_name = {}
  vim.iter(opts.messages):each(function(message)
    local role = message.role
    -- Gemini requires strictly alternating roles; insert a filler turn when
    -- two consecutive messages share a role.
    if role == prev_role then
      if role == M.role_map["user"] then
        table.insert(
          contents,
          { role = M.role_map["assistant"], parts = {
            { text = "Ok, I understand." },
          } }
        )
      else
        table.insert(contents, { role = M.role_map["user"], parts = {
          { text = "Ok" },
        } })
      end
    end
    prev_role = role
    local parts = {}
    local content_items = message.content
    if type(content_items) == "string" then
      table.insert(parts, { text = content_items })
    elseif type(content_items) == "table" then
      ---@cast content_items AvanteLLMMessageContentItem[]
      for _, item in ipairs(content_items) do
        if type(item) == "string" then
          table.insert(parts, { text = item })
        elseif type(item) == "table" and item.type == "text" then
          table.insert(parts, { text = item.text })
        elseif type(item) == "table" and item.type == "image" then
          table.insert(parts, {
            inline_data = {
              mime_type = "image/png",
              data = item.source.data,
            },
          })
        elseif type(item) == "table" and item.type == "tool_use" and not use_ReAct_prompt then
          tool_id_to_name[item.id] = item.name
          role = "model"
          table.insert(parts, {
            functionCall = {
              name = item.name,
              args = item.input,
            },
          })
        elseif type(item) == "table" and item.type == "tool_result" and not use_ReAct_prompt then
          role = "function"
          -- Tool output may be JSON or plain text; fall back to the raw string.
          local ok, content = pcall(vim.json.decode, item.content)
          if not ok then
            content = item.content
          end
          -- item.name here refers to the name of the tool that was called,
          -- which is available in the tool_result content item prepared by llm.lua
          local tool_name = item.name
          if not tool_name then
            -- Fallback, though item.name should ideally always be present for tool_result
            tool_name = tool_id_to_name[item.tool_use_id]
          end
          table.insert(parts, {
            functionResponse = {
              name = tool_name,
              response = {
                name = tool_name, -- Gemini API requires the name in the response object as well
                content = content,
              },
            },
          })
        elseif type(item) == "table" and item.type == "thinking" then
          table.insert(parts, { text = item.thinking })
        elseif type(item) == "table" and item.type == "redacted_thinking" then
          table.insert(parts, { text = item.data })
        end
      end
      -- In ReAct mode tool results are re-rendered as plain text turns that
      -- quote the original tool_use XML.
      if not provider_conf.disable_tools and use_ReAct_prompt then
        if content_items[1].type == "tool_result" then
          local tool_use_msg = nil
          for _, msg_ in ipairs(opts.messages) do
            if type(msg_.content) == "table" and #msg_.content > 0 then
              if
                msg_.content[1].type == "tool_use"
                and msg_.content[1].id == content_items[1].tool_use_id
              then
                tool_use_msg = msg_
                break
              end
            end
          end
          if tool_use_msg then
            table.insert(contents, {
              role = "model",
              parts = {
                { text = Utils.tool_use_to_xml(tool_use_msg.content[1]) },
              },
            })
            role = "user"
            table.insert(parts, {
              text = "The result of tool use "
                .. Utils.tool_use_to_xml(tool_use_msg.content[1])
                .. " is:\n",
            })
            table.insert(parts, {
              text = content_items[1].content,
            })
          end
        end
      end
    end
    if #parts > 0 then
      table.insert(contents, { role = M.role_map[role] or role, parts = parts })
    end
  end)

  -- Attach pasted images to the last turn.
  if Clipboard.support_paste_image() and opts.image_paths then
    for _, image_path in ipairs(opts.image_paths) do
      local image_data = {
        inline_data = {
          mime_type = "image/png",
          data = Clipboard.get_base64_content(image_path),
        },
      }

      table.insert(contents[#contents].parts, image_data)
    end
  end

  local system_prompt = opts.system_prompt

  if use_ReAct_prompt then
    system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts)
  end

  return {
    systemInstruction = {
      role = "user",
      parts = {
        {
          text = system_prompt,
        },
      },
    },
    contents = contents,
  }
end

--- Prepares the main request body for Gemini-like APIs.
---@param provider_instance AvanteProviderFunctor The provider instance (self).
---@param prompt_opts AvantePromptOptions Prompt options including messages, tools, system_prompt.
---@param provider_conf table Provider configuration from config.lua (e.g., model, top-level temperature/max_tokens).
---@param request_body_ table Request-specific overrides, typically from provider_conf.request_config_overrides.
---@return table The fully constructed request body.
function M.prepare_request_body(provider_instance, prompt_opts, provider_conf, request_body_)
  local request_body = {}
  request_body.generationConfig = request_body_.generationConfig or {}

  local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true

  if use_ReAct_prompt then
    -- NOTE(review): an empty-string stop sequence looks suspicious (possibly a
    -- stripped tag from an earlier edit) — confirm the intended value.
    request_body.generationConfig.stopSequences = { "" }
  end

  local disable_tools = provider_conf.disable_tools or false

  if not use_ReAct_prompt and not disable_tools and prompt_opts.tools then
    local function_declarations = {}
    for _, tool in ipairs(prompt_opts.tools) do
      table.insert(function_declarations, provider_instance:transform_to_function_declaration(tool))
    end

    if #function_declarations > 0 then
      request_body.tools = {
        {
          functionDeclarations = function_declarations,
        },
      }
    end
  end

  return vim.tbl_deep_extend("force", {}, provider_instance:parse_messages(prompt_opts), request_body)
end

--- Map Gemini usage metadata onto Avante's token-usage shape.
---@param usage avante.GeminiTokenUsage | nil
---@return avante.LLMTokenUsage | nil
function M.transform_gemini_usage(usage)
  if not usage then
    return nil
  end
  ---@type avante.LLMTokenUsage
  local res = {
    prompt_tokens = usage.promptTokenCount,
    completion_tokens = usage.candidatesTokenCount,
  }
  return res
end

--- Handle one SSE chunk from streamGenerateContent: emit text, collect tool
--- calls, and translate finishReason into an on_stop event.
function M:parse_response(ctx, data_stream, _, opts)
  local ok, jsn = pcall(vim.json.decode, data_stream)
  if not ok then
    opts.on_stop({ reason = "error", error = "Failed to parse JSON response: " .. tostring(jsn) })
    return
  end

  -- (The original's `jsn.usageMetadata and jsn.usageMetadata ~= nil` was
  -- redundant; a single truthiness check is equivalent.)
  if opts.update_tokens_usage and jsn.usageMetadata then
    local usage = M.transform_gemini_usage(jsn.usageMetadata)
    if usage ~= nil then
      opts.update_tokens_usage(usage)
    end
  end

  -- Handle prompt feedback first, as it might indicate an overall issue with the prompt
  if jsn.promptFeedback and jsn.promptFeedback.blockReason then
    local feedback = jsn.promptFeedback
    OpenAI:finish_pending_messages(ctx, opts) -- Ensure any pending messages are cleared
    opts.on_stop({
      reason = "error",
      error = "Prompt blocked or filtered. Reason: " .. feedback.blockReason,
      details = feedback,
    })
    return
  end

  if jsn.candidates and #jsn.candidates > 0 then
    local candidate = jsn.candidates[1]
    ---@type AvanteLLMToolUse[]
    ctx.tool_use_list = ctx.tool_use_list or {}

    -- Check if candidate.content and candidate.content.parts exist before iterating
    if candidate.content and candidate.content.parts then
      for _, part in ipairs(candidate.content.parts) do
        if part.text then
          if opts.on_chunk then
            opts.on_chunk(part.text)
          end
          OpenAI:add_text_message(ctx, part.text, "generating", opts)
        elseif part.functionCall then
          if not ctx.function_call_id then
            ctx.function_call_id = 0
          end
          ctx.function_call_id = ctx.function_call_id + 1
          local tool_use = {
            id = ctx.turn_id .. "-" .. tostring(ctx.function_call_id),
            name = part.functionCall.name,
            input_json = vim.json.encode(part.functionCall.args),
          }
          table.insert(ctx.tool_use_list, tool_use)
          OpenAI:add_tool_use_message(ctx, tool_use, "generated", opts)
        end
      end
    end

    -- Check for finishReason to determine if this candidate's stream is done.
    if candidate.finishReason then
      OpenAI:finish_pending_messages(ctx, opts)
      local reason_str = candidate.finishReason
      local stop_details = { finish_reason = reason_str }
      stop_details.usage = M.transform_gemini_usage(jsn.usageMetadata)

      if reason_str == "TOOL_CODE" then
        -- Model indicates a tool-related stop.
        -- The tool_use list is added to the table in llm.lua
        opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
      elseif reason_str == "STOP" then
        if ctx.tool_use_list and #ctx.tool_use_list > 0 then
          -- Natural stop, but tools were found in this final chunk.
          opts.on_stop(vim.tbl_deep_extend("force", { reason = "tool_use" }, stop_details))
        else
          -- Natural stop, no tools in this final chunk.
          -- llm.lua will check its accumulated tools if tool_choice was active.
          opts.on_stop(vim.tbl_deep_extend("force", { reason = "complete" }, stop_details))
        end
      elseif reason_str == "MAX_TOKENS" then
        opts.on_stop(vim.tbl_deep_extend("force", { reason = "max_tokens" }, stop_details))
      elseif reason_str == "SAFETY" or reason_str == "RECITATION" then
        opts.on_stop(
          vim.tbl_deep_extend(
            "force",
            { reason = "error", error = "Generation stopped: " .. reason_str },
            stop_details
          )
        )
      else -- OTHER, FINISH_REASON_UNSPECIFIED, or any other unhandled reason.
        opts.on_stop(
          vim.tbl_deep_extend(
            "force",
            { reason = "error", error = "Generation stopped with unhandled reason: " .. reason_str },
            stop_details
          )
        )
      end
    end
    -- If no finishReason, it's an intermediate chunk; do not call on_stop.
  end
end

--- Build the curl request description for a streaming generateContent call.
---@param prompt_opts AvantePromptOptions
---@return AvanteCurlOutput|nil nil when the API key is missing
function M:parse_curl_args(prompt_opts)
  local provider_conf, request_body = Providers.parse_config(self)

  local api_key = self:parse_api_key()
  if api_key == nil then
    Utils.error("Gemini: API key is not set. Please set " .. M.api_key_name)
    return nil
  end

  -- NOTE(review): the API key rides in the URL query string here, so it may
  -- appear in proxy/server logs — confirm this matches upstream usage.
  return {
    url = Utils.url_join(
      provider_conf.endpoint,
      provider_conf.model .. ":streamGenerateContent?alt=sse&key=" .. api_key
    ),
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
    headers = Utils.tbl_override({ ["Content-Type"] = "application/json" }, self.extra_headers),
    body = M.prepare_request_body(self, prompt_opts, provider_conf, request_body),
  }
end

return M
--- Resolve the configured Ollama host URL.
---@return string host
local function get_host()
  local config = require("codetyper").get_config()
  return config.llm.ollama.host
end

--- Resolve the configured Ollama model name.
---@return string model
local function get_model()
  local config = require("codetyper").get_config()
  return config.llm.ollama.model
end

--- Assemble the non-streaming /api/generate request payload.
---@param prompt string User prompt
---@param context table Context information for the system prompt
---@return table body
local function build_request_body(prompt, context)
  return {
    model = get_model(),
    system = llm.build_system_prompt(context),
    prompt = prompt,
    stream = false,
    options = {
      temperature = 0.2,
      num_predict = 4096,
    },
  }
end

--- POST the request to Ollama's /api/generate endpoint via curl.
---@param body table Request body
---@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function
local function make_request(body, callback)
  local endpoint = get_host() .. "/api/generate"
  local payload = vim.json.encode(body)

  local deliver = function(code, err, usage)
    vim.schedule(function()
      callback(code, err, usage)
    end)
  end

  vim.fn.jobstart({
    "curl",
    "-s",
    "-X",
    "POST",
    endpoint,
    "-H",
    "Content-Type: application/json",
    "-d",
    payload,
  }, {
    stdout_buffered = true,
    on_stdout = function(_, data)
      -- jobstart hands back {""} when curl produced no output at all.
      if not data or #data == 0 or (data[1] == "" and #data == 1) then
        return
      end

      local raw = table.concat(data, "\n")
      local ok, decoded = pcall(vim.json.decode, raw)
      if not ok then
        deliver(nil, "Failed to parse Ollama response", nil)
        return
      end

      if decoded.error then
        deliver(nil, decoded.error or "Ollama API error", nil)
        return
      end

      -- Token accounting reported by Ollama alongside the completion.
      local usage = {
        prompt_tokens = decoded.prompt_eval_count or 0,
        response_tokens = decoded.eval_count or 0,
      }

      if decoded.response == nil then
        deliver(nil, "No response from Ollama", nil)
      else
        deliver(llm.extract_code(decoded.response), nil, usage)
      end
    end,
    on_stderr = function(_, data)
      if data and #data > 0 and data[1] ~= "" then
        deliver(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil)
      end
    end,
    on_exit = function(_, code)
      if code ~= 0 then
        deliver(nil, "Ollama API request failed with code: " .. code, nil)
      end
    end,
  })
end

--- Generate code using Ollama API.
---@param prompt string The user prompt
---@param context table Context information
---@param callback fun(response: string|nil, error: string|nil) Callback function
function M.generate(prompt, context, callback)
  local logs = require("codetyper.agent.logs")

  logs.request("ollama", get_model())
  logs.thinking("Building request body...")

  local body = build_request_body(prompt, context)

  -- Rough prompt-size estimate for the log, based on the encoded payload.
  logs.debug(string.format("Estimated prompt: ~%d tokens", logs.estimate_tokens(vim.json.encode(body))))
  logs.thinking("Sending to Ollama API...")

  utils.notify("Sending request to Ollama...", vim.log.levels.INFO)

  make_request(body, function(response, err, usage)
    if err then
      logs.error(err)
      utils.notify(err, vim.log.levels.ERROR)
      callback(nil, err)
      return
    end

    if usage then
      logs.response(usage.prompt_tokens or 0, usage.response_tokens or 0, "end_turn")
    end
    logs.thinking("Response received, extracting code...")
    logs.info("Code generated successfully")
    utils.notify("Code generated successfully", vim.log.levels.INFO)
    callback(response, nil)
  end)
end
logs.response(usage.prompt_tokens or 0, usage.response_tokens or 0, "end_turn") + end + logs.thinking("Response received, extracting code...") + logs.info("Code generated successfully") + utils.notify("Code generated successfully", vim.log.levels.INFO) + callback(response, nil) + end + end) end --- Check if Ollama is reachable ---@param callback fun(ok: boolean, error: string|nil) Callback function function M.health_check(callback) - local host = get_host() + local host = get_host() - local cmd = { "curl", "-s", host .. "/api/tags" } + local cmd = { "curl", "-s", host .. "/api/tags" } - vim.fn.jobstart(cmd, { - stdout_buffered = true, - on_stdout = function(_, data) - if data and #data > 0 and data[1] ~= "" then - vim.schedule(function() - callback(true, nil) - end) - end - end, - on_exit = function(_, code) - if code ~= 0 then - vim.schedule(function() - callback(false, "Cannot connect to Ollama at " .. host) - end) - end - end, - }) + vim.fn.jobstart(cmd, { + stdout_buffered = true, + on_stdout = function(_, data) + if data and #data > 0 and data[1] ~= "" then + vim.schedule(function() + callback(true, nil) + end) + end + end, + on_exit = function(_, code) + if code ~= 0 then + vim.schedule(function() + callback(false, "Cannot connect to Ollama at " .. host) + end) + end + end, + }) end --- Check if Ollama is properly configured ---@return boolean, string? 
Valid status and optional error message function M.validate() - local host = get_host() - if not host or host == "" then - return false, "Ollama host not configured" - end - local model = get_model() - if not model or model == "" then - return false, "Ollama model not configured" - end - return true + local host = get_host() + if not host or host == "" then + return false, "Ollama host not configured" + end + local model = get_model() + if not model or model == "" then + return false, "Ollama model not configured" + end + return true end --- Build system prompt for agent mode with tool instructions ---@param context table Context information ---@return string System prompt local function build_agent_system_prompt(context) - local agent_prompts = require("codetyper.prompts.agent") - local tools_module = require("codetyper.agent.tools") + local agent_prompts = require("codetyper.prompts.agent") + local tools_module = require("codetyper.agent.tools") - local system_prompt = agent_prompts.system .. "\n\n" - system_prompt = system_prompt .. tools_module.to_prompt_format() .. "\n\n" - system_prompt = system_prompt .. agent_prompts.tool_instructions + local system_prompt = agent_prompts.system .. "\n\n" + system_prompt = system_prompt .. tools_module.to_prompt_format() .. "\n\n" + system_prompt = system_prompt .. agent_prompts.tool_instructions - -- Add context about current file if available - if context.file_path then - system_prompt = system_prompt .. "\n\nCurrent working context:\n" - system_prompt = system_prompt .. "- File: " .. context.file_path .. "\n" - if context.language then - system_prompt = system_prompt .. "- Language: " .. context.language .. "\n" - end - end + -- Add context about current file if available + if context.file_path then + system_prompt = system_prompt .. "\n\nCurrent working context:\n" + system_prompt = system_prompt .. "- File: " .. context.file_path .. "\n" + if context.language then + system_prompt = system_prompt .. "- Language: " .. 
context.language .. "\n" + end + end - -- Add project root info - local root = utils.get_project_root() - if root then - system_prompt = system_prompt .. "- Project root: " .. root .. "\n" - end + -- Add project root info + local root = utils.get_project_root() + if root then + system_prompt = system_prompt .. "- Project root: " .. root .. "\n" + end - return system_prompt + return system_prompt end --- Build request body for Ollama API with tools (chat format) @@ -234,114 +233,117 @@ end ---@param context table Context information ---@return table Request body local function build_tools_request_body(messages, context) - local system_prompt = build_agent_system_prompt(context) + local system_prompt = build_agent_system_prompt(context) - -- Convert messages to Ollama chat format - local ollama_messages = {} - for _, msg in ipairs(messages) do - local content = msg.content - -- Handle complex content (like tool results) - if type(content) == "table" then - local text_parts = {} - for _, part in ipairs(content) do - if part.type == "tool_result" then - table.insert(text_parts, "[" .. (part.name or "tool") .. " result]: " .. (part.content or "")) - elseif part.type == "text" then - table.insert(text_parts, part.text or "") - end - end - content = table.concat(text_parts, "\n") - end + -- Convert messages to Ollama chat format + local ollama_messages = {} + for _, msg in ipairs(messages) do + local content = msg.content + -- Handle complex content (like tool results) + if type(content) == "table" then + local text_parts = {} + for _, part in ipairs(content) do + if part.type == "tool_result" then + table.insert(text_parts, "[" .. (part.name or "tool") .. " result]: " .. 
(part.content or "")) + elseif part.type == "text" then + table.insert(text_parts, part.text or "") + end + end + content = table.concat(text_parts, "\n") + end - table.insert(ollama_messages, { - role = msg.role, - content = content, - }) - end + table.insert(ollama_messages, { + role = msg.role, + content = content, + }) + end - return { - model = get_model(), - messages = ollama_messages, - system = system_prompt, - stream = false, - options = { - temperature = 0.3, - num_predict = 4096, - }, - } + return { + model = get_model(), + messages = ollama_messages, + system = system_prompt, + stream = false, + options = { + temperature = 0.3, + num_predict = 4096, + }, + } end --- Make HTTP request to Ollama chat API ---@param body table Request body ---@param callback fun(response: string|nil, error: string|nil, usage: table|nil) Callback function local function make_chat_request(body, callback) - local host = get_host() - local url = host .. "/api/chat" - local json_body = vim.json.encode(body) + local host = get_host() + local url = host .. 
"/api/chat" + local json_body = vim.json.encode(body) - local cmd = { - "curl", - "-s", - "-X", "POST", - url, - "-H", "Content-Type: application/json", - "-d", json_body, - } + local cmd = { + "curl", + "-s", + "-X", + "POST", + url, + "-H", + "Content-Type: application/json", + "-d", + json_body, + } - vim.fn.jobstart(cmd, { - stdout_buffered = true, - on_stdout = function(_, data) - if not data or #data == 0 or (data[1] == "" and #data == 1) then - return - end + vim.fn.jobstart(cmd, { + stdout_buffered = true, + on_stdout = function(_, data) + if not data or #data == 0 or (data[1] == "" and #data == 1) then + return + end - local response_text = table.concat(data, "\n") - local ok, response = pcall(vim.json.decode, response_text) + local response_text = table.concat(data, "\n") + local ok, response = pcall(vim.json.decode, response_text) - if not ok then - vim.schedule(function() - callback(nil, "Failed to parse Ollama response", nil) - end) - return - end + if not ok then + vim.schedule(function() + callback(nil, "Failed to parse Ollama response", nil) + end) + return + end - if response.error then - vim.schedule(function() - callback(nil, response.error or "Ollama API error", nil) - end) - return - end + if response.error then + vim.schedule(function() + callback(nil, response.error or "Ollama API error", nil) + end) + return + end - -- Extract usage info - local usage = { - prompt_tokens = response.prompt_eval_count or 0, - response_tokens = response.eval_count or 0, - } + -- Extract usage info + local usage = { + prompt_tokens = response.prompt_eval_count or 0, + response_tokens = response.eval_count or 0, + } - -- Return the message content for agent parsing - if response.message and response.message.content then - vim.schedule(function() - callback(response.message.content, nil, usage) - end) - else - vim.schedule(function() - callback(nil, "No response from Ollama", nil) - end) - end - end, - on_stderr = function(_, data) - if data and #data > 0 and 
data[1] ~= "" then - vim.schedule(function() - callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil) - end) - end - end, - on_exit = function(_, code) - if code ~= 0 then - -- Don't double-report errors - end - end, - }) + -- Return the message content for agent parsing + if response.message and response.message.content then + vim.schedule(function() + callback(response.message.content, nil, usage) + end) + else + vim.schedule(function() + callback(nil, "No response from Ollama", nil) + end) + end + end, + on_stderr = function(_, data) + if data and #data > 0 and data[1] ~= "" then + vim.schedule(function() + callback(nil, "Ollama API request failed: " .. table.concat(data, "\n"), nil) + end) + end + end, + on_exit = function(_, code) + if code ~= 0 then + -- Don't double-report errors + end + end, + }) end --- Generate response with tools using Ollama API @@ -350,50 +352,46 @@ end ---@param tools table Tool definitions (embedded in prompt for Ollama) ---@param callback fun(response: string|nil, error: string|nil) Callback function function M.generate_with_tools(messages, context, tools, callback) - local logs = require("codetyper.agent.logs") + local logs = require("codetyper.agent.logs") - -- Log the request - local model = get_model() - logs.request("ollama", model) - logs.thinking("Preparing API request...") + -- Log the request + local model = get_model() + logs.request("ollama", model) + logs.thinking("Preparing API request...") - local body = build_tools_request_body(messages, context) + local body = build_tools_request_body(messages, context) - -- Estimate prompt tokens - local prompt_estimate = logs.estimate_tokens(vim.json.encode(body)) - logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate)) + -- Estimate prompt tokens + local prompt_estimate = logs.estimate_tokens(vim.json.encode(body)) + logs.debug(string.format("Estimated prompt: ~%d tokens", prompt_estimate)) - make_chat_request(body, function(response, 
err, usage) - if err then - logs.error(err) - callback(nil, err) - else - -- Log token usage - if usage then - logs.response( - usage.prompt_tokens or 0, - usage.response_tokens or 0, - "end_turn" - ) - end + make_chat_request(body, function(response, err, usage) + if err then + logs.error(err) + callback(nil, err) + else + -- Log token usage + if usage then + logs.response(usage.prompt_tokens or 0, usage.response_tokens or 0, "end_turn") + end - -- Log if response contains tool calls - if response then - local parser = require("codetyper.agent.parser") - local parsed = parser.parse_ollama_response(response) - if #parsed.tool_calls > 0 then - for _, tc in ipairs(parsed.tool_calls) do - logs.thinking("Tool call: " .. tc.name) - end - end - if parsed.text and parsed.text ~= "" then - logs.thinking("Response contains text") - end - end + -- Log if response contains tool calls + if response then + local parser = require("codetyper.agent.parser") + local parsed = parser.parse_ollama_response(response) + if #parsed.tool_calls > 0 then + for _, tc in ipairs(parsed.tool_calls) do + logs.thinking("Tool call: " .. 
tc.name) + end + end + if parsed.text and parsed.text ~= "" then + logs.thinking("Response contains text") + end + end - callback(response, nil) - end - end) + callback(response, nil) + end + end) end return M diff --git a/lua/codetyper/llm/openai.lua b/lua/codetyper/llm/openai.lua new file mode 100644 index 0000000..b6dbdb8 --- /dev/null +++ b/lua/codetyper/llm/openai.lua @@ -0,0 +1,973 @@ +local Utils = require("avante.utils") +local Config = require("avante.config") +local Clipboard = require("avante.clipboard") +local Providers = require("avante.providers") +local HistoryMessage = require("avante.history.message") +local ReActParser = require("avante.libs.ReAct_parser2") +local JsonParser = require("avante.libs.jsonparser") +local Prompts = require("avante.utils.prompts") +local LlmTools = require("avante.llm_tools") + +---@class AvanteProviderFunctor +local M = {} + +M.api_key_name = "OPENAI_API_KEY" + +M.role_map = { + user = "user", + assistant = "assistant", +} + +function M:is_disable_stream() + return false +end + +---@param tool AvanteLLMTool +---@return AvanteOpenAITool +function M:transform_tool(tool) + local input_schema_properties, required = Utils.llm_tool_param_fields_to_json_schema(tool.param.fields) + ---@type AvanteOpenAIToolFunctionParameters + local parameters = { + type = "object", + properties = input_schema_properties, + required = required, + additionalProperties = false, + } + ---@type AvanteOpenAITool + local res = { + type = "function", + ["function"] = { + name = tool.name, + description = tool.get_description and tool.get_description() or tool.description, + parameters = parameters, + }, + } + return res +end + +function M.is_openrouter(url) + return url:match("^https://openrouter%.ai/") +end + +function M.is_mistral(url) + return url:match("^https://api%.mistral%.ai/") +end + +---@param opts AvantePromptOptions +function M.get_user_message(opts) + vim.deprecate("get_user_message", "parse_messages", "0.1.0", "avante.nvim") + return 
table.concat( + vim.iter(opts.messages) + :filter(function(_, value) + return value == nil or value.role ~= "user" + end) + :fold({}, function(acc, value) + acc = vim.list_extend({}, acc) + acc = vim.list_extend(acc, { value.content }) + return acc + end), + "\n" + ) +end + +function M.is_reasoning_model(model) + return model + and (string.match(model, "^o%d+") ~= nil or (string.match(model, "gpt%-5") ~= nil and model ~= "gpt-5-chat")) +end + +function M.set_allowed_params(provider_conf, request_body) + local use_response_api = Providers.resolve_use_response_api(provider_conf, nil) + if M.is_reasoning_model(provider_conf.model) then + -- Reasoning models have specific parameter requirements + request_body.temperature = 1 + -- Response API doesn't support temperature for reasoning models + if use_response_api then + request_body.temperature = nil + end + else + request_body.reasoning_effort = nil + request_body.reasoning = nil + end + -- If max_tokens is set in config, unset max_completion_tokens + if request_body.max_tokens then + request_body.max_completion_tokens = nil + end + + -- Handle Response API specific parameters + if use_response_api then + -- Convert reasoning_effort to reasoning object for Response API + if request_body.reasoning_effort then + request_body.reasoning = { + effort = request_body.reasoning_effort, + } + request_body.reasoning_effort = nil + end + + -- Response API doesn't support some parameters + -- Remove unsupported parameters for Response API + local unsupported_params = { + "top_p", + "frequency_penalty", + "presence_penalty", + "logit_bias", + "logprobs", + "top_logprobs", + "n", + } + for _, param in ipairs(unsupported_params) do + request_body[param] = nil + end + end +end + +function M:parse_messages(opts) + local messages = {} + local provider_conf, _ = Providers.parse_config(self) + local use_response_api = Providers.resolve_use_response_api(provider_conf, opts) + + local use_ReAct_prompt = provider_conf.use_ReAct_prompt == 
true + local system_prompt = opts.system_prompt + + if use_ReAct_prompt then + system_prompt = Prompts.get_ReAct_system_prompt(provider_conf, opts) + end + + if self.is_reasoning_model(provider_conf.model) then + table.insert(messages, { role = "developer", content = system_prompt }) + else + table.insert(messages, { role = "system", content = system_prompt }) + end + + local has_tool_use = false + + vim.iter(opts.messages):each(function(msg) + if type(msg.content) == "string" then + table.insert(messages, { role = self.role_map[msg.role], content = msg.content }) + elseif type(msg.content) == "table" then + -- Check if this is a reasoning message (object with type "reasoning") + if msg.content.type == "reasoning" then + -- Add reasoning message directly (for Response API) + table.insert(messages, { + type = "reasoning", + id = msg.content.id, + encrypted_content = msg.content.encrypted_content, + summary = msg.content.summary, + }) + return + end + + local content = {} + local tool_calls = {} + local tool_results = {} + for _, item in ipairs(msg.content) do + if type(item) == "string" then + table.insert(content, { type = "text", text = item }) + elseif item.type == "text" then + table.insert(content, { type = "text", text = item.text }) + elseif item.type == "image" then + table.insert(content, { + type = "image_url", + image_url = { + url = "data:" + .. item.source.media_type + .. ";" + .. item.source.type + .. "," + .. 
item.source.data, + }, + }) + elseif item.type == "reasoning" then + -- Add reasoning message directly (for Response API) + table.insert(messages, { + type = "reasoning", + id = item.id, + encrypted_content = item.encrypted_content, + summary = item.summary, + }) + elseif item.type == "tool_use" and not use_ReAct_prompt then + has_tool_use = true + table.insert(tool_calls, { + id = item.id, + type = "function", + ["function"] = { name = item.name, arguments = vim.json.encode(item.input) }, + }) + elseif item.type == "tool_result" and has_tool_use and not use_ReAct_prompt then + table.insert( + tool_results, + { + tool_call_id = item.tool_use_id, + content = item.is_error and "Error: " .. item.content or item.content, + } + ) + end + end + if not provider_conf.disable_tools and use_ReAct_prompt then + if msg.content[1].type == "tool_result" then + local tool_use_msg = nil + for _, msg_ in ipairs(opts.messages) do + if type(msg_.content) == "table" and #msg_.content > 0 then + if + msg_.content[1].type == "tool_use" + and msg_.content[1].id == msg.content[1].tool_use_id + then + tool_use_msg = msg_ + break + end + end + end + if tool_use_msg then + msg.role = "user" + table.insert(content, { + type = "text", + text = "The result of tool use " + .. Utils.tool_use_to_xml(tool_use_msg.content[1]) + .. 
" is:\n", + }) + table.insert(content, { + type = "text", + text = msg.content[1].content, + }) + end + end + end + if #content > 0 then + table.insert(messages, { role = self.role_map[msg.role], content = content }) + end + if not provider_conf.disable_tools and not use_ReAct_prompt then + if #tool_calls > 0 then + -- Only skip tool_calls if using Response API with previous_response_id support + -- Copilot uses Response API format but doesn't support previous_response_id + local should_include_tool_calls = not use_response_api + or not provider_conf.support_previous_response_id + + if should_include_tool_calls then + -- For Response API without previous_response_id support (like Copilot), + -- convert tool_calls to function_call items in input + if use_response_api then + for _, tool_call in ipairs(tool_calls) do + table.insert(messages, { + type = "function_call", + call_id = tool_call.id, + name = tool_call["function"].name, + arguments = tool_call["function"].arguments, + }) + end + else + -- Chat Completions API format + local last_message = messages[#messages] + if + last_message + and last_message.role == self.role_map["assistant"] + and last_message.tool_calls + then + last_message.tool_calls = vim.list_extend(last_message.tool_calls, tool_calls) + + if not last_message.content then + last_message.content = "" + end + else + table.insert( + messages, + { role = self.role_map["assistant"], tool_calls = tool_calls, content = "" } + ) + end + end + end + -- If support_previous_response_id is true, Response API manages function call history + -- So we can skip adding tool_calls to input messages + end + if #tool_results > 0 then + for _, tool_result in ipairs(tool_results) do + -- Response API uses different format for function outputs + if use_response_api then + table.insert(messages, { + type = "function_call_output", + call_id = tool_result.tool_call_id, + output = tool_result.content or "", + }) + else + table.insert( + messages, + { + role = "tool", + 
tool_call_id = tool_result.tool_call_id, + content = tool_result.content or "", + } + ) + end + end + end + end + end + end) + + if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then + local message_content = messages[#messages].content + if type(message_content) ~= "table" or message_content[1] == nil then + message_content = { { type = "text", text = message_content } } + end + for _, image_path in ipairs(opts.image_paths) do + table.insert(message_content, { + type = "image_url", + image_url = { + url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path), + }, + }) + end + messages[#messages].content = message_content + end + + local final_messages = {} + local prev_role = nil + local prev_type = nil + + vim.iter(messages):each(function(message) + local role = message.role + if + role == prev_role + and role ~= "tool" + and prev_type ~= "function_call" + and prev_type ~= "function_call_output" + then + if role == self.role_map["assistant"] then + table.insert(final_messages, { role = self.role_map["user"], content = "Ok" }) + else + table.insert(final_messages, { role = self.role_map["assistant"], content = "Ok, I understand." }) + end + else + if role == "user" and prev_role == "tool" and M.is_mistral(provider_conf.endpoint) then + table.insert(final_messages, { role = self.role_map["assistant"], content = "Ok, I understand." 
}) + end + end + prev_role = role + prev_type = message.type + table.insert(final_messages, message) + end) + + return final_messages +end + +function M:finish_pending_messages(ctx, opts) + if ctx.content ~= nil and ctx.content ~= "" then + self:add_text_message(ctx, "", "generated", opts) + end + if ctx.tool_use_map then + for _, tool_use in pairs(ctx.tool_use_map) do + if tool_use.state == "generating" then + self:add_tool_use_message(ctx, tool_use, "generated", opts) + end + end + end +end + +local llm_tool_names = nil + +function M:add_text_message(ctx, text, state, opts) + if llm_tool_names == nil then + llm_tool_names = LlmTools.get_tool_names() + end + if ctx.content == nil then + ctx.content = "" + end + ctx.content = ctx.content .. text + local content = + ctx.content:gsub("", ""):gsub("", ""):gsub("", ""):gsub("", "") + ctx.content = content + local msg = HistoryMessage:new("assistant", ctx.content, { + state = state, + uuid = ctx.content_uuid, + original_content = ctx.content, + }) + ctx.content_uuid = msg.uuid + local msgs = { msg } + local xml_content = ctx.content + local xml_lines = vim.split(xml_content, "\n") + local cleaned_xml_lines = {} + local prev_tool_name = nil + for _, line in ipairs(xml_lines) do + if line:match("") then + local tool_name = line:match("(.*)") + if tool_name then + prev_tool_name = tool_name + end + elseif line:match("") then + if prev_tool_name then + table.insert(cleaned_xml_lines, "<" .. prev_tool_name .. 
">") + end + goto continue + elseif line:match("") then + if prev_tool_name then + table.insert(cleaned_xml_lines, "") + end + goto continue + end + table.insert(cleaned_xml_lines, line) + ::continue:: + end + local cleaned_xml_content = table.concat(cleaned_xml_lines, "\n") + local xml = ReActParser.parse(cleaned_xml_content) + if xml and #xml > 0 then + local new_content_list = {} + local xml_md_openned = false + for idx, item in ipairs(xml) do + if item.type == "text" then + local cleaned_lines = {} + local lines = vim.split(item.text, "\n") + for _, line in ipairs(lines) do + if line:match("^```xml") or line:match("^```tool_code") or line:match("^```tool_use") then + xml_md_openned = true + elseif line:match("^```$") then + if xml_md_openned then + xml_md_openned = false + else + table.insert(cleaned_lines, line) + end + else + table.insert(cleaned_lines, line) + end + end + table.insert(new_content_list, table.concat(cleaned_lines, "\n")) + goto continue + end + if not vim.tbl_contains(llm_tool_names, item.tool_name) then + goto continue + end + local input = {} + for k, v in pairs(item.tool_input or {}) do + local ok, jsn = pcall(vim.json.decode, v) + if ok and jsn then + input[k] = jsn + else + input[k] = v + end + end + if next(input) ~= nil then + local msg_uuid = ctx.content_uuid .. "-" .. 
idx + local tool_use_id = msg_uuid + local tool_message_state = item.partial and "generating" or "generated" + local msg_ = HistoryMessage:new("assistant", { + type = "tool_use", + name = item.tool_name, + id = tool_use_id, + input = input, + }, { + state = tool_message_state, + uuid = msg_uuid, + turn_id = ctx.turn_id, + }) + msgs[#msgs + 1] = msg_ + ctx.tool_use_map = ctx.tool_use_map or {} + local input_json = type(input) == "string" and input or vim.json.encode(input) + local exists = false + for _, tool_use in pairs(ctx.tool_use_map) do + if tool_use.id == tool_use_id then + tool_use.input_json = input_json + exists = true + end + end + if not exists then + local tool_key = tostring(vim.tbl_count(ctx.tool_use_map)) + ctx.tool_use_map[tool_key] = { + uuid = tool_use_id, + id = tool_use_id, + name = item.tool_name, + input_json = input_json, + state = "generating", + } + end + opts.on_stop({ reason = "tool_use", streaming_tool_use = item.partial }) + end + ::continue:: + end + msg.message.content = table.concat(new_content_list, "\n"):gsub("\n+$", "\n") + end + if opts.on_messages_add then + opts.on_messages_add(msgs) + end +end + +function M:add_thinking_message(ctx, text, state, opts) + if ctx.reasonging_content == nil then + ctx.reasonging_content = "" + end + ctx.reasonging_content = ctx.reasonging_content .. 
text
  -- NOTE(review): "reasonging" is misspelled but used consistently for this
  -- context field (content + uuid below); renaming must happen everywhere at
  -- once, including the setter above this window — do not fix piecemeal.
  local msg = HistoryMessage:new("assistant", {
    type = "thinking",
    thinking = ctx.reasonging_content,
    signature = "",
  }, {
    state = state,
    uuid = ctx.reasonging_content_uuid,
    turn_id = ctx.turn_id,
  })
  -- Reuse the message uuid so subsequent thinking deltas update the same
  -- history entry in place instead of appending new ones.
  ctx.reasonging_content_uuid = msg.uuid
  if opts.on_messages_add then
    opts.on_messages_add({ msg })
  end
end

--- Publish (or re-publish) a streaming tool-use message into the history.
--- The accumulated argument JSON is parsed on every call; while arguments are
--- still streaming the parse may fail and the input falls back to {}.
---@param ctx table Per-response streaming context (supplies turn_id)
---@param tool_use table Accumulated call: name, id, input_json, uuid, state
---@param state string "generating" while arguments stream, "generated" when final
---@param opts table Callbacks: on_messages_add, on_stop
function M:add_tool_use_message(ctx, tool_use, state, opts)
  local jsn = JsonParser.parse(tool_use.input_json)
  local msg = HistoryMessage:new("assistant", {
    type = "tool_use",
    name = tool_use.name,
    id = tool_use.id,
    input = jsn or {},
  }, {
    state = state,
    uuid = tool_use.uuid,
    turn_id = ctx.turn_id,
  })
  -- Keep a stable uuid on the accumulator so later calls update the same entry.
  tool_use.uuid = msg.uuid
  tool_use.state = state
  if opts.on_messages_add then
    opts.on_messages_add({ msg })
  end
  if state == "generating" then
    -- NOTE(review): on_stop is called unguarded here while on_messages_add is
    -- nil-checked above — confirm on_stop is always provided by callers.
    opts.on_stop({ reason = "tool_use", streaming_tool_use = true })
  end
end

--- Record a Response-API "reasoning" output item in the chat history.
--- Stored as already-"generated": reasoning items arrive whole, not as deltas.
---@param ctx table Per-response streaming context (supplies turn_id)
---@param reasoning_item table Item with id, encrypted_content, summary
---@param opts table Callbacks: on_messages_add
function M:add_reasoning_message(ctx, reasoning_item, opts)
  local msg = HistoryMessage:new("assistant", {
    type = "reasoning",
    id = reasoning_item.id,
    encrypted_content = reasoning_item.encrypted_content,
    summary = reasoning_item.summary,
  }, {
    state = "generated",
    uuid = Utils.uuid(),
    turn_id = ctx.turn_id,
  })
  if opts.on_messages_add then
    opts.on_messages_add({ msg })
  end
end

--- Map an OpenAI-style usage payload onto avante's token-usage shape.
--- Only prompt_tokens/completion_tokens are carried over; other usage fields
--- (e.g. cached or reasoning token counts) are intentionally dropped here.
---@param usage avante.OpenAITokenUsage | nil
---@return avante.LLMTokenUsage | nil nil when usage is absent or vim.NIL
function M.transform_openai_usage(usage)
  if not usage then
    return nil
  end
  -- vim.json.decode represents JSON null as the vim.NIL sentinel, which is
  -- truthy in Lua, so it needs its own check.
  if usage == vim.NIL then
    return nil
  end
  ---@type avante.LLMTokenUsage
  local res = {
    prompt_tokens = usage.prompt_tokens,
    completion_tokens = usage.completion_tokens,
  }
  return res
end

--- Parse one SSE "data:" payload from an OpenAI-compatible streaming endpoint.
--- Dispatches between two wire formats: the event-driven Response API (payloads
--- carrying a string "type" field) and the classic Chat Completions delta
--- format. Text, thinking and tool-call fragments are forwarded through the
--- callbacks in opts; accumulation state lives in ctx across calls.
---@param ctx table Mutable per-response state: tool_use_map, think-tag flags,
---                 last_think_content, last_response_id, turn_id
---@param data_stream string Raw JSON (or "[DONE]" sentinel) of one SSE line
---@param opts table Callbacks: on_chunk, on_stop, on_messages_add,
---                  update_tokens_usage
function M:parse_response(ctx, data_stream, _, opts)
  -- Terminal sentinel: flush anything pending and report why we stopped.
  if data_stream:match('"%[DONE%]":') or data_stream == "[DONE]" then
    self:finish_pending_messages(ctx, opts)
    if ctx.tool_use_map and vim.tbl_count(ctx.tool_use_map) > 0 then
      ctx.tool_use_map = {}
      opts.on_stop({ reason = "tool_use" })
    else
      opts.on_stop({ reason = "complete" })
    end
    return
  end

  local jsn = vim.json.decode(data_stream)

  -- Check if this is a Response API event (has 'type' field)
  if jsn.type and type(jsn.type) == "string" then
    -- Response API event-driven format
    if jsn.type == "response.output_text.delta" then
      -- Text content delta
      if jsn.delta and jsn.delta ~= vim.NIL and jsn.delta ~= "" then
        if opts.on_chunk then
          opts.on_chunk(jsn.delta)
        end
        self:add_text_message(ctx, jsn.delta, "generating", opts)
      end
    elseif jsn.type == "response.reasoning_summary_text.delta" then
      -- Reasoning summary delta
      if jsn.delta and jsn.delta ~= vim.NIL and jsn.delta ~= "" then
        -- Emit a newline once before the first thinking chunk of a response.
        if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
          ctx.returned_think_start_tag = true
          if opts.on_chunk then
            opts.on_chunk("\n")
          end
        end
        -- Remembered so the closing separator can decide between "\n" and "\n\n".
        ctx.last_think_content = jsn.delta
        self:add_thinking_message(ctx, jsn.delta, "generating", opts)
        if opts.on_chunk then
          opts.on_chunk(jsn.delta)
        end
      end
    elseif jsn.type == "response.function_call_arguments.delta" then
      -- Function call arguments delta
      if jsn.delta and jsn.delta ~= vim.NIL and jsn.delta ~= "" then
        if not ctx.tool_use_map then
          ctx.tool_use_map = {}
        end
        -- Tool calls are keyed by output_index so parallel calls accumulate
        -- independently.
        local tool_key = tostring(jsn.output_index or 0)
        if not ctx.tool_use_map[tool_key] then
          ctx.tool_use_map[tool_key] = {
            name = jsn.name or "",
            id = jsn.call_id or "",
            input_json = jsn.delta,
          }
        else
          ctx.tool_use_map[tool_key].input_json = ctx.tool_use_map[tool_key].input_json .. jsn.delta
        end
      end
    elseif jsn.type == "response.output_item.added" then
      -- Output item added (could be function call or reasoning)
      if jsn.item and jsn.item.type == "function_call" then
        local tool_key = tostring(jsn.output_index or 0)
        if not ctx.tool_use_map then
          ctx.tool_use_map = {}
        end
        ctx.tool_use_map[tool_key] = {
          name = jsn.item.name or "",
          id = jsn.item.call_id or jsn.item.id or "",
          input_json = "",
        }
        self:add_tool_use_message(ctx, ctx.tool_use_map[tool_key], "generating", opts)
      elseif jsn.item and jsn.item.type == "reasoning" then
        -- Add reasoning item to history
        self:add_reasoning_message(ctx, jsn.item, opts)
      end
    elseif jsn.type == "response.output_item.done" then
      -- Output item done (finalize function call)
      if jsn.item and jsn.item.type == "function_call" then
        local tool_key = tostring(jsn.output_index or 0)
        if ctx.tool_use_map and ctx.tool_use_map[tool_key] then
          local tool_use = ctx.tool_use_map[tool_key]
          -- The done event carries the complete argument string; prefer it
          -- over whatever was accumulated from deltas.
          if jsn.item.arguments then
            tool_use.input_json = jsn.item.arguments
          end
          self:add_tool_use_message(ctx, tool_use, "generated", opts)
        end
      end
    elseif jsn.type == "response.completed" or jsn.type == "response.done" then
      -- Response completed - save response.id for future requests
      if jsn.response and jsn.response.id then
        ctx.last_response_id = jsn.response.id
        -- Store in provider for next request
        self.last_response_id = jsn.response.id
      end
      -- Close the thinking section if it was opened but never closed.
      if
        ctx.returned_think_start_tag ~= nil
        and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
      then
        ctx.returned_think_end_tag = true
        if opts.on_chunk then
          if
            ctx.last_think_content
            and ctx.last_think_content ~= vim.NIL
            and ctx.last_think_content:sub(-1) ~= "\n"
          then
            opts.on_chunk("\n\n")
          else
            opts.on_chunk("\n")
          end
        end
        self:add_thinking_message(ctx, "", "generated", opts)
      end
      self:finish_pending_messages(ctx, opts)
      local usage = nil
      if jsn.response and jsn.response.usage then
        usage = self.transform_openai_usage(jsn.response.usage)
      end
      if ctx.tool_use_map and vim.tbl_count(ctx.tool_use_map) > 0 then
        opts.on_stop({ reason = "tool_use", usage = usage })
      else
        opts.on_stop({ reason = "complete", usage = usage })
      end
    elseif jsn.type == "error" then
      -- Error event
      local error_msg = jsn.error and vim.inspect(jsn.error) or "Unknown error"
      opts.on_stop({ reason = "error", error = error_msg })
    end
    return
  end

  -- Chat Completions API format (original code)
  if jsn.usage and jsn.usage ~= vim.NIL then
    if opts.update_tokens_usage then
      local usage = self.transform_openai_usage(jsn.usage)
      if usage then
        opts.update_tokens_usage(usage)
      end
    end
  end
  if jsn.error and jsn.error ~= vim.NIL then
    opts.on_stop({ reason = "error", error = vim.inspect(jsn.error) })
    return
  end
  ---@cast jsn AvanteOpenAIChatResponse
  if not jsn.choices then
    return
  end
  local choice = jsn.choices[1]
  if not choice then
    return
  end
  local delta = choice.delta
  if not delta then
    -- o1-family models respond non-streamed: the payload carries a full
    -- message instead of a delta.
    local provider_conf = Providers.parse_config(self)
    if provider_conf.model:match("o1") then
      delta = choice.message
    end
  end
  if not delta then
    return
  end
  -- Thinking content under two vendor spellings: reasoning_content (DeepSeek
  -- style) and reasoning (OpenRouter style).
  if delta.reasoning_content and delta.reasoning_content ~= vim.NIL and delta.reasoning_content ~= "" then
    if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
      ctx.returned_think_start_tag = true
      if opts.on_chunk then
        opts.on_chunk("\n")
      end
    end
    ctx.last_think_content = delta.reasoning_content
    self:add_thinking_message(ctx, delta.reasoning_content, "generating", opts)
    if opts.on_chunk then
      opts.on_chunk(delta.reasoning_content)
    end
  elseif delta.reasoning and delta.reasoning ~= vim.NIL then
    if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
      ctx.returned_think_start_tag = true
      if opts.on_chunk then
        opts.on_chunk("\n")
      end
    end
    ctx.last_think_content = delta.reasoning
    self:add_thinking_message(ctx, delta.reasoning, "generating", opts)
    if opts.on_chunk then
      opts.on_chunk(delta.reasoning)
    end
  elseif delta.tool_calls and delta.tool_calls ~= vim.NIL then
    local choice_index = choice.index or 0
    for idx, tool_call in ipairs(delta.tool_calls) do
      --- In Gemini's so-called OpenAI Compatible API, tool_call.index is nil, which is quite absurd! Therefore, a compatibility fix is needed here.
      if tool_call.index == nil then
        tool_call.index = choice_index + idx - 1
      end
      if not ctx.tool_use_map then
        ctx.tool_use_map = {}
      end
      local tool_key = tostring(tool_call.index)
      local prev_tool_key = tostring(tool_call.index - 1)
      if not ctx.tool_use_map[tool_key] then
        -- A new index implies the previous tool call finished streaming;
        -- finalize it before starting the new accumulator.
        local prev_tool_use = ctx.tool_use_map[prev_tool_key]
        if tool_call.index > 0 and prev_tool_use then
          self:add_tool_use_message(ctx, prev_tool_use, "generated", opts)
        end
        local tool_use = {
          name = tool_call["function"].name,
          id = tool_call.id,
          input_json = type(tool_call["function"].arguments) == "string" and tool_call["function"].arguments
            or "",
        }
        ctx.tool_use_map[tool_key] = tool_use
        self:add_tool_use_message(ctx, tool_use, "generating", opts)
      else
        local tool_use = ctx.tool_use_map[tool_key]
        if tool_call["function"].arguments == vim.NIL then
          tool_call["function"].arguments = ""
        end
        tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
        -- self:add_tool_use_message(ctx, tool_use, "generating", opts)
      end
    end
  elseif delta.content then
    -- First text chunk after thinking: close the thinking section with an
    -- appropriate separator.
    if
      ctx.returned_think_start_tag ~= nil
      and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
    then
      ctx.returned_think_end_tag = true
      if opts.on_chunk then
        if
          ctx.last_think_content
          and ctx.last_think_content ~= vim.NIL
          and ctx.last_think_content:sub(-1) ~= "\n"
        then
          opts.on_chunk("\n\n")
        else
          opts.on_chunk("\n")
        end
      end
      self:add_thinking_message(ctx, "", "generated", opts)
    end
    if delta.content ~= vim.NIL then
      if opts.on_chunk then
        opts.on_chunk(delta.content)
      end
      self:add_text_message(ctx, delta.content, "generating", opts)
    end
  end
  if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" or choice.finish_reason == "length" then
    self:finish_pending_messages(ctx, opts)
    if ctx.tool_use_map and vim.tbl_count(ctx.tool_use_map) > 0 then
      opts.on_stop({ reason = "tool_use", usage = self.transform_openai_usage(jsn.usage) })
    else
      opts.on_stop({ reason = "complete", usage = self.transform_openai_usage(jsn.usage) })
    end
  end
  if choice.finish_reason == "tool_calls" then
    self:finish_pending_messages(ctx, opts)
    opts.on_stop({
      reason = "tool_use",
      usage = self.transform_openai_usage(jsn.usage),
    })
  end
end

--- Handle a complete (non-streamed) Chat Completions response body.
--- Forwards the first choice's message content and then signals completion.
--- NOTE(review): when the response has no choices/content, on_stop is never
--- invoked — confirm callers tolerate that (e.g. via their own timeout).
---@param data string Raw JSON response body
---@param opts table Callbacks: on_chunk, on_stop, on_messages_add
function M:parse_response_without_stream(data, _, opts)
  ---@type AvanteOpenAIChatResponse
  local json = vim.json.decode(data)
  if json.choices and json.choices[1] then
    local choice = json.choices[1]
    if choice.message and choice.message.content then
      if opts.on_chunk then
        opts.on_chunk(choice.message.content)
      end
      self:add_text_message({}, choice.message.content, "generated", opts)
      vim.schedule(function()
        opts.on_stop({ reason = "complete" })
      end)
    end
  end
end

--- Build the curl request (url, headers, body) for this provider.
--- Targets either the Response API (/responses, body.input) or the classic
--- Chat Completions API (/chat/completions, body.messages) depending on
--- provider config; returns nil when a required API key is missing.
---@param prompt_opts AvantePromptOptions
---@return AvanteCurlOutput|nil
function M:parse_curl_args(prompt_opts)
  local provider_conf, request_body = Providers.parse_config(self)
  local disable_tools = provider_conf.disable_tools or false

  local headers = {
    ["Content-Type"] = "application/json",
  }

  if Providers.env.require_api_key(provider_conf) then
    local api_key = self.parse_api_key()
    if api_key == nil then
      Utils.error(
        Config.provider .. ": API key is not set, please set it in your environment variable or config file"
      )
      return nil
    end
    headers["Authorization"] = "Bearer " .. api_key
  end

  -- OpenRouter-specific attribution headers and reasoning opt-in.
  if M.is_openrouter(provider_conf.endpoint) then
    headers["HTTP-Referer"] = "https://github.com/yetone/avante.nvim"
    headers["X-Title"] = "Avante.nvim"
    request_body.include_reasoning = true
  end

  self.set_allowed_params(provider_conf, request_body)
  local use_response_api = Providers.resolve_use_response_api(provider_conf, prompt_opts)

  local use_ReAct_prompt = provider_conf.use_ReAct_prompt == true

  -- Native tool definitions are skipped under ReAct prompting (tools are
  -- described in the prompt text instead).
  local tools = nil
  if not disable_tools and prompt_opts.tools and not use_ReAct_prompt then
    tools = {}
    for _, tool in ipairs(prompt_opts.tools) do
      local transformed_tool = self:transform_tool(tool)
      -- Response API uses flattened tool structure
      if use_response_api then
        -- Convert from {type: "function", function: {name, description, parameters}}
        -- to {type: "function", name, description, parameters}
        if transformed_tool.type == "function" and transformed_tool["function"] then
          transformed_tool = {
            type = "function",
            name = transformed_tool["function"].name,
            description = transformed_tool["function"].description,
            parameters = transformed_tool["function"].parameters,
          }
        end
      end
      table.insert(tools, transformed_tool)
    end
  end

  Utils.debug("endpoint", provider_conf.endpoint)
  Utils.debug("model", provider_conf.model)

  local stop = nil
  if use_ReAct_prompt then
    -- NOTE(review): the ReAct stop sequence is an empty string here — that
    -- looks like a lost token; confirm the intended stop marker.
    stop = { "" }
  end

  -- Determine endpoint path based on use_response_api
  local endpoint_path = use_response_api and "/responses" or "/chat/completions"

  local parsed_messages = self:parse_messages(prompt_opts)

  -- Build base body
  local base_body = {
    model = provider_conf.model,
    stop = stop,
    stream = true,
    tools = tools,
  }

  -- Response API uses 'input' instead of 'messages'
  if use_response_api then
    -- Check if we have tool results - if so, use previous_response_id
    local has_function_outputs = false
    for _, msg in ipairs(parsed_messages) do
      if msg.type == "function_call_output" then
        has_function_outputs = true
        break
      end
    end

    if has_function_outputs and self.last_response_id then
      -- When sending function outputs, use previous_response_id
      base_body.previous_response_id = self.last_response_id
      -- Only send the function outputs, not the full history
      local function_outputs = {}
      for _, msg in ipairs(parsed_messages) do
        if msg.type == "function_call_output" then
          table.insert(function_outputs, msg)
        end
      end
      base_body.input = function_outputs
      -- Clear the stored response_id after using it
      self.last_response_id = nil
    else
      -- Normal request without tool results
      base_body.input = parsed_messages
    end

    -- Response API uses max_output_tokens instead of max_tokens/max_completion_tokens
    if request_body.max_completion_tokens then
      request_body.max_output_tokens = request_body.max_completion_tokens
      request_body.max_completion_tokens = nil
    end
    if request_body.max_tokens then
      request_body.max_output_tokens = request_body.max_tokens
      request_body.max_tokens = nil
    end
    -- Response API doesn't use stream_options
    base_body.stream_options = nil
  else
    base_body.messages = parsed_messages
    -- Mistral rejects stream_options; everyone else gets usage reporting.
    base_body.stream_options = not M.is_mistral(provider_conf.endpoint) and {
      include_usage = true,
    } or nil
  end

  -- request_body (user-configured extras) wins over base_body on conflicts.
  return {
    url = Utils.url_join(provider_conf.endpoint, endpoint_path),
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
    headers = Utils.tbl_override(headers, self.extra_headers),
    body = vim.tbl_deep_extend("force", base_body, request_body),
  }
end

return M