diff --git a/lua/avante/config.lua b/lua/avante/config.lua index e22dd29..d2a50d6 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -26,7 +26,7 @@ M._defaults = { ---@alias avante.Mode "agentic" | "legacy" ---@type avante.Mode mode = "agentic", - ---@alias avante.ProviderName "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | "bedrock" | "ollama" | string + ---@alias avante.ProviderName "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | "bedrock" | "ollama" | "watsonx_code_assistant" | string ---@type avante.ProviderName provider = "claude", -- WARNING: Since auto-suggestions are a high-frequency operation and therefore expensive, @@ -348,6 +348,16 @@ M._defaults = { }, }, }, + ---@type AvanteSupportedProvider + watsonx_code_assistant = { + endpoint = "https://api.dataplatform.cloud.ibm.com/v2/wca/core/chat/text/generation", + model = "granite-8b-code-instruct", + timeout = 30000, -- Timeout in milliseconds + extra_request_body = { + -- Additional watsonx-specific parameters can be added here + }, + }, + ---@type AvanteSupportedProvider vertex_claude = { endpoint = "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/antrhopic/models", diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 69edc97..446d270 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -583,8 +583,33 @@ function M.curl(opts) local curl_body_file = temp_file .. "-request-body.json" local resp_body_file = temp_file .. "-response-body.txt" local headers_file = temp_file .. "-response-headers.txt" - local json_content = vim.json.encode(spec.body) - fn.writefile(vim.split(json_content, "\n"), curl_body_file) + + -- Check if this is a multipart form request (specifically for watsonx) + local is_multipart_form = spec.headers and spec.headers["Content-Type"] == "multipart/form-data" + local curl_options + + if is_multipart_form then + -- For multipart form data, use the form parameter + -- spec.body should be a table with form field data + curl_options = { + headers = spec.headers, + proxy = spec.proxy, + insecure = spec.insecure, + form = spec.body, + raw = spec.rawArgs, + } + else + -- For regular JSON requests, encode as JSON and write to file + local json_content = vim.json.encode(spec.body) + fn.writefile(vim.split(json_content, "\n"), curl_body_file) + curl_options = { + headers = spec.headers, + proxy = spec.proxy, + insecure = spec.insecure, + body = curl_body_file, + raw = spec.rawArgs, + } + end Utils.debug("curl request body file:", curl_body_file) Utils.debug("curl response body file:", resp_body_file) @@ -599,122 +624,121 @@ function M.curl(opts) local headers_reported = false - local started_job, new_active_job = pcall(curl.post, spec.url, { - headers = spec.headers, - proxy = spec.proxy, - insecure = spec.insecure, - body = curl_body_file, - raw = spec.rawArgs, - dump = { "-D", headers_file }, - stream = function(err, data, _) - if not headers_reported and opts.on_response_headers then - headers_reported = true - opts.on_response_headers(parse_headers(headers_file)) - end - if err then - completed = true - handler_opts.on_stop({ reason = "error", error = err }) - return - end - if not data then return end - if Config.debug then - if type(data) == "string" then - local file = io.open(resp_body_file, "a") - if file then - file:write(data .. "\n") - file:close() - end + local started_job, new_active_job = pcall( + curl.post, + spec.url, + vim.tbl_extend("force", curl_options, { + dump = { "-D", headers_file }, + stream = function(err, data, _) + if not headers_reported and opts.on_response_headers then + headers_reported = true + opts.on_response_headers(parse_headers(headers_file)) end - end - vim.schedule(function() - if provider.parse_stream_data ~= nil then - provider:parse_stream_data(turn_ctx, data, handler_opts) - else - parse_stream_data(data) - end - end) - end, - on_error = function(err) - if err.exit == 23 then - local xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR") - if not xdg_runtime_dir or fn.isdirectory(xdg_runtime_dir) == 0 then - Utils.error( - "$XDG_RUNTIME_DIR=" - .. xdg_runtime_dir - .. " is set but does not exist. curl could not write output. Please make sure it exists, or unset.", - { title = "Avante" } - ) - elseif not uv.fs_access(xdg_runtime_dir, "w") then - Utils.error( - "$XDG_RUNTIME_DIR=" - .. xdg_runtime_dir - .. " exists but is not writable. curl could not write output. Please make sure it is writable, or unset.", - { title = "Avante" } - ) - end - end - - active_job = nil - if not completed then - completed = true - cleanup() - handler_opts.on_stop({ reason = "error", error = err }) - end - end, - callback = function(result) - active_job = nil - cleanup() - local headers_map = vim.iter(result.headers):fold({}, function(acc, value) - local pieces = vim.split(value, ":") - local key = pieces[1] - local remain = vim.list_slice(pieces, 2) - if not remain then return acc end - local val = Utils.trim_spaces(table.concat(remain, ":")) - acc[key] = val - return acc - end) - if result.status >= 400 then - if provider.on_error then - provider.on_error(result) - else - Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" }) - end - local retry_after = 10 - if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end - if result.status == 429 then - handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after }) + if err then + completed = true + handler_opts.on_stop({ reason = "error", error = err }) return end + if not data then return end + if Config.debug then + if type(data) == "string" then + local file = io.open(resp_body_file, "a") + if file then + file:write(data .. "\n") + file:close() + end + end + end vim.schedule(function() - if not completed then - completed = true - handler_opts.on_stop({ - reason = "error", - error = "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body), - }) + if provider.parse_stream_data ~= nil then + provider:parse_stream_data(turn_ctx, data, handler_opts) + else + parse_stream_data(data) end end) - end - - -- If stream is not enabled, then handle the response here - if provider:is_disable_stream() and result.status == 200 then - vim.schedule(function() - completed = true - parse_response_without_stream(result.body) - end) - end - - if result.status == 200 and spec.url:match("https://openrouter.ai") then - local content_type = headers_map["content-type"] - if content_type and content_type:match("text/html") then - handler_opts.on_stop({ - reason = "error", - error = "Your openrouter endpoint setting is incorrect, please set it to https://openrouter.ai/api/v1", - }) + end, + on_error = function(err) + if err.exit == 23 then + local xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR") + if not xdg_runtime_dir or fn.isdirectory(xdg_runtime_dir) == 0 then + Utils.error( + "$XDG_RUNTIME_DIR=" + .. xdg_runtime_dir + .. " is set but does not exist. curl could not write output. Please make sure it exists, or unset.", + { title = "Avante" } + ) + elseif not uv.fs_access(xdg_runtime_dir, "w") then + Utils.error( + "$XDG_RUNTIME_DIR=" + .. xdg_runtime_dir + .. " exists but is not writable. curl could not write output. Please make sure it is writable, or unset.", + { title = "Avante" } + ) + end end - end - end, - }) + + active_job = nil + if not completed then + completed = true + cleanup() + handler_opts.on_stop({ reason = "error", error = err }) + end + end, + callback = function(result) + active_job = nil + cleanup() + local headers_map = vim.iter(result.headers):fold({}, function(acc, value) + local pieces = vim.split(value, ":") + local key = pieces[1] + local remain = vim.list_slice(pieces, 2) + if not remain then return acc end + local val = Utils.trim_spaces(table.concat(remain, ":")) + acc[key] = val + return acc + end) + if result.status >= 400 then + if provider.on_error then + provider.on_error(result) + else + Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" }) + end + local retry_after = 10 + if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end + if result.status == 429 then + handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after }) + return + end + vim.schedule(function() + if not completed then + completed = true + handler_opts.on_stop({ + reason = "error", + error = "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body), + }) + end + end) + end + + -- If stream is not enabled, then handle the response here + if provider:is_disable_stream() and result.status == 200 then + vim.schedule(function() + completed = true + parse_response_without_stream(result.body) + end) + end + + if result.status == 200 and spec.url:match("https://openrouter.ai") then + local content_type = headers_map["content-type"] + if content_type and content_type:match("text/html") then + handler_opts.on_stop({ + reason = "error", + error = "Your openrouter endpoint setting is incorrect, please set it to https://openrouter.ai/api/v1", + }) + end + end + end, + }) + ) if not started_job then local error_msg = vim.inspect(new_active_job) diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index 8c59c1a..4c2ee72 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -13,6 +13,7 @@ local Utils = require("avante.utils") ---@field bedrock AvanteBedrockProviderFunctor ---@field ollama AvanteProviderFunctor ---@field vertex_claude AvanteProviderFunctor +---@field watsonx_code_assistant AvanteProviderFunctor local M = {} ---@class EnvironmentHandler diff --git a/lua/avante/providers/watsonx_code_assistant.lua b/lua/avante/providers/watsonx_code_assistant.lua new file mode 100644 index 0000000..8e0273e --- /dev/null +++ b/lua/avante/providers/watsonx_code_assistant.lua @@ -0,0 +1,285 @@ +-- Documentation for setting up IBM Watsonx Code Assistant +--- Generating an access token: https://www.ibm.com/products/watsonx-code-assistant or https://github.ibm.com/code-assistant/wca-api +local P = require("avante.providers") +local Utils = require("avante.utils") +local curl = require("plenary.curl") +local Config = require("avante.config") +local Llm = require("avante.llm") +local ts_utils = pcall(require, "nvim-treesitter.ts_utils") and require("nvim-treesitter.ts_utils") + or { + get_node_at_cursor = function() return nil end, + } +local OpenAI = require("avante.providers.openai") + +---@class AvanteProviderFunctor +local M = {} + +M.api_key_name = "WCA_API_KEY" -- The name of the environment variable that contains the API key +M.role_map = { + user = "USER", + assistant = "ASSISTANT", + system = "SYSTEM", +} +M.last_iam_token_time = nil +M.iam_bearer_token = "" + +function M:is_disable_stream() return true end + +---@type fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): table +function M:parse_messages(opts) + if opts == nil then return {} end + local messages + if opts.system_prompt == "WCA_COMMAND" then + messages = {} + else + messages = { + { content = opts.system_prompt, role = "SYSTEM" }, + } + end + vim + .iter(opts.messages) + :each(function(msg) table.insert(messages, { content = msg.content, role = M.role_map[msg.role] }) end) + return messages +end + +--- This function will be used to parse incoming SSE stream +--- It takes in the data stream as the first argument, followed by SSE event state, and opts +--- retrieved from given buffer. +--- This opts include: +--- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk +--- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk +local function parse_response_wo_stream(self, data, _, opts) + if Utils.debug then Utils.debug("WCA parse_response_without_stream called with opts: " .. vim.inspect(opts)) end + + local json = vim.json.decode(data) + if Utils.debug then Utils.debug("WCA Response: " .. vim.inspect(json)) end + if json.error ~= nil and json.error ~= vim.NIL then + Utils.warn("WCA Error " .. tostring(json.error.code) .. ": " .. tostring(json.error.message)) + end + if json.response and json.response.message and json.response.message.content then + local content = json.response.message.content + + if Utils.debug then Utils.debug("WCA Original Content: " .. tostring(content)) end + + -- Clean up the content by removing XML-like tags that are not part of the actual response + -- These tags appear to be internal formatting from watsonx that should not be shown to users + -- Use more careful patterns to avoid removing too much content + content = content:gsub("\n?", "") + content = content:gsub("\n?", "") + content = content:gsub("\n?.-\n?", "") + content = content:gsub("\n?.-\n?", "") + content = content:gsub("\n?.-\n?", "") + + -- Trim excessive whitespace but preserve structure + content = content:gsub("^\n+", ""):gsub("\n+$", "") + + if Utils.debug then Utils.debug("WCA Cleaned Content: " .. tostring(content)) end + + -- Ensure we still have content after cleaning + if content and content ~= "" then + if opts.on_chunk then opts.on_chunk(content) end + -- Add the text message for UI display (similar to OpenAI provider) + OpenAI:add_text_message({}, content, "generated", opts) + else + Utils.warn("WCA: Content became empty after cleaning") + if opts.on_chunk then + opts.on_chunk(json.response.message.content) -- Fallback to original content + end + -- Add the original content as fallback + OpenAI:add_text_message({}, json.response.message.content, "generated", opts) + end + vim.schedule(function() + if opts.on_stop then opts.on_stop({ reason = "complete" }) end + end) + elseif json.error and json.error ~= vim.NIL then + vim.schedule(function() + if opts.on_stop then + opts.on_stop({ + reason = "error", + error = "WCA Error " .. tostring(json.error.code) .. ": " .. tostring(json.error.message), + }) + end + end) + else + -- Handle case where there's no response content and no explicit error + if Utils.debug then Utils.debug("WCA: No content found in response, treating as empty response") end + vim.schedule(function() + if opts.on_stop then opts.on_stop({ reason = "complete" }) end + end) + end +end + +M.parse_response_without_stream = parse_response_wo_stream + +-- Needs to be language specific for each function and methods. +local get_function_name_under_cursor = function() + local current_node = ts_utils.get_node_at_cursor() + if not current_node then return "" end + local expr = current_node + + while expr do + if expr:type() == "function_definition" or expr:type() == "method_declaration" then break end + expr = expr:parent() + end + + if not expr then return "" end + + local result = (ts_utils.get_node_text(expr:child(1)))[1] + return result +end + +--- It takes in the provider options as the first argument, followed by code_opts retrieved from given buffer. +---@type fun(command_name: string): nil +M.method_command = function(command_name) + if + command_name ~= "document" + and command_name ~= "unit-test" + and command_name ~= "explain" + and command_name:find("translate", 1, true) == 0 + then + Utils.warn("Invalid command name" .. command_name) + end + + local current_buffer = vim.api.nvim_get_current_buf() + local file_path = vim.api.nvim_buf_get_name(current_buffer) + + -- Use file name for now. For proper extraction of method names, a lang specific TreeSitter querry is need + -- local method_name = get_function_name_under_cursor() + -- use whole file if we cannot get the method + local method_name = "" + if method_name == "" then + local path_splits = vim.split(file_path, "/") + method_name = path_splits[#path_splits] + end + + local sidebar = require("avante").get() + if not sidebar then + require("avante.api").ask() + sidebar = require("avante").get() + end + if not sidebar:is_open() then sidebar:open({}) end + sidebar.file_selector:add_current_buffer() + + local response_content = "" + local provider = P[Config.provider] + local content = "/" .. command_name .. " @" .. method_name + Llm.curl({ + provider = provider, + prompt_opts = { + system_prompt = "WCA_COMMAND", + messages = { + { content = content, role = "user" }, + }, + selected_files = sidebar.file_selector:get_selected_files_contents(), + }, + handler_opts = { + on_start = function(_) end, + on_chunk = function(chunk) + if not chunk then return end + response_content = response_content .. chunk + end, + on_stop = function(stop_opts) + if stop_opts.error ~= nil then + Utils.error(string.format("WCA Command " .. command_name .. " failed: %s", vim.inspect(stop_opts.error))) + return + end + if stop_opts.reason == "complete" then + if not sidebar:is_open() then sidebar:open({}) end + sidebar:update_content(response_content, { focus = true }) + end + end, + }, + }) +end + +local function get_iam_bearer_token(provider) + if M.last_iam_token_time ~= nil and os.time() - M.last_iam_token_time <= 3550 then return M.iam_bearer_token end + + local api_key = provider.parse_api_key() + if api_key == nil then + -- if no api key is available, make a request with a empty api key. + api_key = "" + end + + local url = "https://iam.cloud.ibm.com/identity/token" + local header = { ["Content-Type"] = "application/x-www-form-urlencoded" } + local body = "grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey=" .. api_key + + local response = curl.post(url, { headers = header, body = body }) + if response.status == 200 then + -- select first key value pair + local access_token_field = vim.split(response.body, ",")[1] + -- get value + local token = vim.split(access_token_field, ":")[2] + -- remove quotes + M.iam_bearer_token = (token:gsub("^%p(.*)%p$", "%1")) + M.last_iam_token_time = os.time() + else + Utils.error( + "Failed to retrieve IAM token: " .. response.status .. ": " .. vim.inspect(response.body), + { title = "Avante WCA" } + ) + M.iam_bearer_token = "" + end + return M.iam_bearer_token +end + +local random = math.random +math.randomseed(os.time()) +local function uuid() + local template = "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx" + return string.gsub(template, "[xy]", function(c) + local v = (c == "x") and random(0, 0xf) or random(8, 0xb) + return string.format("%x", v) + end) +end + +--- This function below will be used to parse in cURL arguments. +--- It takes in the provider options as the first argument, followed by code_opts retrieved from given buffer. +--- This code_opts include: +--- - question: Input from the users +--- - code_lang: the language of given code buffer +--- - code_content: content of code buffer +--- - selected_code_content: (optional) If given code content is selected in visual mode as context. +---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput +---@param provider AvanteProviderFunctor +---@param code_opts AvantePromptOptions +---@return table +M.parse_curl_args = function(provider, code_opts) + local base, _ = P.parse_config(provider) + local headers = { + ["Content-Type"] = "multipart/form-data", + ["Authorization"] = "Bearer " .. get_iam_bearer_token(provider), + ["Request-ID"] = uuid(), + } + + -- Create the message_payload structure as required by WCA API + local message_payload = { + message_payload = { + chat_session_id = uuid(), -- Required for granite-3-8b-instruct model + messages = M:parse_messages(code_opts), + }, + } + + -- Base64 encode the message payload as required by watsonx API + local json_content = vim.json.encode(message_payload) + local encoded_json_content = vim.base64.encode(json_content) + + -- Return form data structure - the message field contains the base64-encoded JSON + local body = { + message = encoded_json_content, + } + + return { + url = base.endpoint, + timeout = base.timeout, + insecure = false, + headers = headers, + body = body, + } +end + +--- The following function SHOULD only be used when providers doesn't follow SSE spec [ADVANCED] +--- this is mutually exclusive with parse_response_data + +return M diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index fdcb7fa..efd2fc0 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -2543,7 +2543,7 @@ function Sidebar:create_input_container() vim.keymap.del("n", "G", { buffer = self.containers.result.bufnr }) end) - if stop_opts.error ~= nil then + if stop_opts.error ~= nil and stop_opts.error ~= vim.NIL then local msg_content = stop_opts.error if type(msg_content) ~= "string" then msg_content = vim.inspect(msg_content) end self:add_history_messages({ diff --git a/tests/providers/watsonx_code_assistant_spec.lua b/tests/providers/watsonx_code_assistant_spec.lua new file mode 100644 index 0000000..dad67cb --- /dev/null +++ b/tests/providers/watsonx_code_assistant_spec.lua @@ -0,0 +1,69 @@ +local busted = require("plenary.busted") + +busted.describe("watsonx_code_assistant provider", function() + local watsonx_provider + + busted.before_each(function() + -- Minimal setup without extensive mocking + watsonx_provider = require("avante.providers.watsonx_code_assistant") + end) + + busted.describe("basic configuration", function() + busted.it("should have required properties", function() + assert.is_not_nil(watsonx_provider.api_key_name) + assert.equals("WCA_API_KEY", watsonx_provider.api_key_name) + assert.is_not_nil(watsonx_provider.role_map) + assert.equals("USER", watsonx_provider.role_map.user) + assert.equals("ASSISTANT", watsonx_provider.role_map.assistant) + end) + + busted.it("should disable streaming", function() assert.is_true(watsonx_provider:is_disable_stream()) end) + + busted.it("should have required functions", function() + assert.is_function(watsonx_provider.parse_messages) + assert.is_function(watsonx_provider.parse_response_without_stream) + assert.is_function(watsonx_provider.parse_curl_args) + end) + end) + + busted.describe("parse_messages", function() + busted.it("should parse messages with correct role mapping", function() + ---@type AvantePromptOptions + local opts = { + system_prompt = "You are a helpful assistant", + messages = { + { content = "Hello", role = "user" }, + { content = "Hi there", role = "assistant" }, + }, + } + + local result = watsonx_provider:parse_messages(opts) + + assert.is_table(result) + assert.equals(3, #result) -- system + 2 messages + assert.equals("SYSTEM", result[1].role) + assert.equals("You are a helpful assistant", result[1].content) + assert.equals("USER", result[2].role) + assert.equals("Hello", result[2].content) + assert.equals("ASSISTANT", result[3].role) + assert.equals("Hi there", result[3].content) + end) + + busted.it("should handle WCA_COMMAND system prompt", function() + ---@type AvantePromptOptions + local opts = { + system_prompt = "WCA_COMMAND", + messages = { + { content = "/document main.py", role = "user" }, + }, + } + + local result = watsonx_provider:parse_messages(opts) + + assert.is_table(result) + assert.equals(1, #result) -- only user message, no system prompt + assert.equals("USER", result[1].role) + assert.equals("/document main.py", result[1].content) + end) + end) +end)