[feat] Add basic support for IBM's watsonx code assistant (#2617)
Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
This commit is contained in:
@@ -26,7 +26,7 @@ M._defaults = {
|
||||
---@alias avante.Mode "agentic" | "legacy"
|
||||
---@type avante.Mode
|
||||
mode = "agentic",
|
||||
---@alias avante.ProviderName "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | "bedrock" | "ollama" | string
|
||||
---@alias avante.ProviderName "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | "bedrock" | "ollama" | "watsonx_code_assistant" | string
|
||||
---@type avante.ProviderName
|
||||
provider = "claude",
|
||||
-- WARNING: Since auto-suggestions are a high-frequency operation and therefore expensive,
|
||||
@@ -348,6 +348,16 @@ M._defaults = {
|
||||
},
|
||||
},
|
||||
},
|
||||
---@type AvanteSupportedProvider
|
||||
watsonx_code_assistant = {
|
||||
endpoint = "https://api.dataplatform.cloud.ibm.com/v2/wca/core/chat/text/generation",
|
||||
model = "granite-8b-code-instruct",
|
||||
timeout = 30000, -- Timeout in milliseconds
|
||||
extra_request_body = {
|
||||
-- Additional watsonx-specific parameters can be added here
|
||||
},
|
||||
},
|
||||
|
||||
---@type AvanteSupportedProvider
|
||||
vertex_claude = {
|
||||
endpoint = "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/antrhopic/models",
|
||||
|
||||
@@ -583,8 +583,33 @@ function M.curl(opts)
|
||||
local curl_body_file = temp_file .. "-request-body.json"
|
||||
local resp_body_file = temp_file .. "-response-body.txt"
|
||||
local headers_file = temp_file .. "-response-headers.txt"
|
||||
local json_content = vim.json.encode(spec.body)
|
||||
fn.writefile(vim.split(json_content, "\n"), curl_body_file)
|
||||
|
||||
-- Check if this is a multipart form request (specifically for watsonx)
|
||||
local is_multipart_form = spec.headers and spec.headers["Content-Type"] == "multipart/form-data"
|
||||
local curl_options
|
||||
|
||||
if is_multipart_form then
|
||||
-- For multipart form data, use the form parameter
|
||||
-- spec.body should be a table with form field data
|
||||
curl_options = {
|
||||
headers = spec.headers,
|
||||
proxy = spec.proxy,
|
||||
insecure = spec.insecure,
|
||||
form = spec.body,
|
||||
raw = spec.rawArgs,
|
||||
}
|
||||
else
|
||||
-- For regular JSON requests, encode as JSON and write to file
|
||||
local json_content = vim.json.encode(spec.body)
|
||||
fn.writefile(vim.split(json_content, "\n"), curl_body_file)
|
||||
curl_options = {
|
||||
headers = spec.headers,
|
||||
proxy = spec.proxy,
|
||||
insecure = spec.insecure,
|
||||
body = curl_body_file,
|
||||
raw = spec.rawArgs,
|
||||
}
|
||||
end
|
||||
|
||||
Utils.debug("curl request body file:", curl_body_file)
|
||||
Utils.debug("curl response body file:", resp_body_file)
|
||||
@@ -599,122 +624,121 @@ function M.curl(opts)
|
||||
|
||||
local headers_reported = false
|
||||
|
||||
local started_job, new_active_job = pcall(curl.post, spec.url, {
|
||||
headers = spec.headers,
|
||||
proxy = spec.proxy,
|
||||
insecure = spec.insecure,
|
||||
body = curl_body_file,
|
||||
raw = spec.rawArgs,
|
||||
dump = { "-D", headers_file },
|
||||
stream = function(err, data, _)
|
||||
if not headers_reported and opts.on_response_headers then
|
||||
headers_reported = true
|
||||
opts.on_response_headers(parse_headers(headers_file))
|
||||
end
|
||||
if err then
|
||||
completed = true
|
||||
handler_opts.on_stop({ reason = "error", error = err })
|
||||
return
|
||||
end
|
||||
if not data then return end
|
||||
if Config.debug then
|
||||
if type(data) == "string" then
|
||||
local file = io.open(resp_body_file, "a")
|
||||
if file then
|
||||
file:write(data .. "\n")
|
||||
file:close()
|
||||
end
|
||||
local started_job, new_active_job = pcall(
|
||||
curl.post,
|
||||
spec.url,
|
||||
vim.tbl_extend("force", curl_options, {
|
||||
dump = { "-D", headers_file },
|
||||
stream = function(err, data, _)
|
||||
if not headers_reported and opts.on_response_headers then
|
||||
headers_reported = true
|
||||
opts.on_response_headers(parse_headers(headers_file))
|
||||
end
|
||||
end
|
||||
vim.schedule(function()
|
||||
if provider.parse_stream_data ~= nil then
|
||||
provider:parse_stream_data(turn_ctx, data, handler_opts)
|
||||
else
|
||||
parse_stream_data(data)
|
||||
end
|
||||
end)
|
||||
end,
|
||||
on_error = function(err)
|
||||
if err.exit == 23 then
|
||||
local xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR")
|
||||
if not xdg_runtime_dir or fn.isdirectory(xdg_runtime_dir) == 0 then
|
||||
Utils.error(
|
||||
"$XDG_RUNTIME_DIR="
|
||||
.. xdg_runtime_dir
|
||||
.. " is set but does not exist. curl could not write output. Please make sure it exists, or unset.",
|
||||
{ title = "Avante" }
|
||||
)
|
||||
elseif not uv.fs_access(xdg_runtime_dir, "w") then
|
||||
Utils.error(
|
||||
"$XDG_RUNTIME_DIR="
|
||||
.. xdg_runtime_dir
|
||||
.. " exists but is not writable. curl could not write output. Please make sure it is writable, or unset.",
|
||||
{ title = "Avante" }
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
active_job = nil
|
||||
if not completed then
|
||||
completed = true
|
||||
cleanup()
|
||||
handler_opts.on_stop({ reason = "error", error = err })
|
||||
end
|
||||
end,
|
||||
callback = function(result)
|
||||
active_job = nil
|
||||
cleanup()
|
||||
local headers_map = vim.iter(result.headers):fold({}, function(acc, value)
|
||||
local pieces = vim.split(value, ":")
|
||||
local key = pieces[1]
|
||||
local remain = vim.list_slice(pieces, 2)
|
||||
if not remain then return acc end
|
||||
local val = Utils.trim_spaces(table.concat(remain, ":"))
|
||||
acc[key] = val
|
||||
return acc
|
||||
end)
|
||||
if result.status >= 400 then
|
||||
if provider.on_error then
|
||||
provider.on_error(result)
|
||||
else
|
||||
Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" })
|
||||
end
|
||||
local retry_after = 10
|
||||
if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end
|
||||
if result.status == 429 then
|
||||
handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after })
|
||||
if err then
|
||||
completed = true
|
||||
handler_opts.on_stop({ reason = "error", error = err })
|
||||
return
|
||||
end
|
||||
if not data then return end
|
||||
if Config.debug then
|
||||
if type(data) == "string" then
|
||||
local file = io.open(resp_body_file, "a")
|
||||
if file then
|
||||
file:write(data .. "\n")
|
||||
file:close()
|
||||
end
|
||||
end
|
||||
end
|
||||
vim.schedule(function()
|
||||
if not completed then
|
||||
completed = true
|
||||
handler_opts.on_stop({
|
||||
reason = "error",
|
||||
error = "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body),
|
||||
})
|
||||
if provider.parse_stream_data ~= nil then
|
||||
provider:parse_stream_data(turn_ctx, data, handler_opts)
|
||||
else
|
||||
parse_stream_data(data)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
-- If stream is not enabled, then handle the response here
|
||||
if provider:is_disable_stream() and result.status == 200 then
|
||||
vim.schedule(function()
|
||||
completed = true
|
||||
parse_response_without_stream(result.body)
|
||||
end)
|
||||
end
|
||||
|
||||
if result.status == 200 and spec.url:match("https://openrouter.ai") then
|
||||
local content_type = headers_map["content-type"]
|
||||
if content_type and content_type:match("text/html") then
|
||||
handler_opts.on_stop({
|
||||
reason = "error",
|
||||
error = "Your openrouter endpoint setting is incorrect, please set it to https://openrouter.ai/api/v1",
|
||||
})
|
||||
end,
|
||||
on_error = function(err)
|
||||
if err.exit == 23 then
|
||||
local xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR")
|
||||
if not xdg_runtime_dir or fn.isdirectory(xdg_runtime_dir) == 0 then
|
||||
Utils.error(
|
||||
"$XDG_RUNTIME_DIR="
|
||||
.. xdg_runtime_dir
|
||||
.. " is set but does not exist. curl could not write output. Please make sure it exists, or unset.",
|
||||
{ title = "Avante" }
|
||||
)
|
||||
elseif not uv.fs_access(xdg_runtime_dir, "w") then
|
||||
Utils.error(
|
||||
"$XDG_RUNTIME_DIR="
|
||||
.. xdg_runtime_dir
|
||||
.. " exists but is not writable. curl could not write output. Please make sure it is writable, or unset.",
|
||||
{ title = "Avante" }
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end,
|
||||
})
|
||||
|
||||
active_job = nil
|
||||
if not completed then
|
||||
completed = true
|
||||
cleanup()
|
||||
handler_opts.on_stop({ reason = "error", error = err })
|
||||
end
|
||||
end,
|
||||
callback = function(result)
|
||||
active_job = nil
|
||||
cleanup()
|
||||
local headers_map = vim.iter(result.headers):fold({}, function(acc, value)
|
||||
local pieces = vim.split(value, ":")
|
||||
local key = pieces[1]
|
||||
local remain = vim.list_slice(pieces, 2)
|
||||
if not remain then return acc end
|
||||
local val = Utils.trim_spaces(table.concat(remain, ":"))
|
||||
acc[key] = val
|
||||
return acc
|
||||
end)
|
||||
if result.status >= 400 then
|
||||
if provider.on_error then
|
||||
provider.on_error(result)
|
||||
else
|
||||
Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" })
|
||||
end
|
||||
local retry_after = 10
|
||||
if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end
|
||||
if result.status == 429 then
|
||||
handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after })
|
||||
return
|
||||
end
|
||||
vim.schedule(function()
|
||||
if not completed then
|
||||
completed = true
|
||||
handler_opts.on_stop({
|
||||
reason = "error",
|
||||
error = "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body),
|
||||
})
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
-- If stream is not enabled, then handle the response here
|
||||
if provider:is_disable_stream() and result.status == 200 then
|
||||
vim.schedule(function()
|
||||
completed = true
|
||||
parse_response_without_stream(result.body)
|
||||
end)
|
||||
end
|
||||
|
||||
if result.status == 200 and spec.url:match("https://openrouter.ai") then
|
||||
local content_type = headers_map["content-type"]
|
||||
if content_type and content_type:match("text/html") then
|
||||
handler_opts.on_stop({
|
||||
reason = "error",
|
||||
error = "Your openrouter endpoint setting is incorrect, please set it to https://openrouter.ai/api/v1",
|
||||
})
|
||||
end
|
||||
end
|
||||
end,
|
||||
})
|
||||
)
|
||||
|
||||
if not started_job then
|
||||
local error_msg = vim.inspect(new_active_job)
|
||||
|
||||
@@ -13,6 +13,7 @@ local Utils = require("avante.utils")
|
||||
---@field bedrock AvanteBedrockProviderFunctor
|
||||
---@field ollama AvanteProviderFunctor
|
||||
---@field vertex_claude AvanteProviderFunctor
|
||||
---@field watsonx_code_assistant AvanteProviderFunctor
|
||||
local M = {}
|
||||
|
||||
---@class EnvironmentHandler
|
||||
|
||||
285
lua/avante/providers/watsonx_code_assistant.lua
Normal file
285
lua/avante/providers/watsonx_code_assistant.lua
Normal file
@@ -0,0 +1,285 @@
|
||||
-- Documentation for setting up IBM Watsonx Code Assistant
|
||||
--- Generating an access token: https://www.ibm.com/products/watsonx-code-assistant or https://github.ibm.com/code-assistant/wca-api
|
||||
local P = require("avante.providers")
|
||||
local Utils = require("avante.utils")
|
||||
local curl = require("plenary.curl")
|
||||
local Config = require("avante.config")
|
||||
local Llm = require("avante.llm")
|
||||
local ts_utils = pcall(require, "nvim-treesitter.ts_utils") and require("nvim-treesitter.ts_utils")
|
||||
or {
|
||||
get_node_at_cursor = function() return nil end,
|
||||
}
|
||||
local OpenAI = require("avante.providers.openai")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
|
||||
M.api_key_name = "WCA_API_KEY" -- The name of the environment variable that contains the API key
|
||||
M.role_map = {
|
||||
user = "USER",
|
||||
assistant = "ASSISTANT",
|
||||
system = "SYSTEM",
|
||||
}
|
||||
M.last_iam_token_time = nil
|
||||
M.iam_bearer_token = ""
|
||||
|
||||
function M:is_disable_stream() return true end
|
||||
|
||||
---@type fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): table
|
||||
function M:parse_messages(opts)
|
||||
if opts == nil then return {} end
|
||||
local messages
|
||||
if opts.system_prompt == "WCA_COMMAND" then
|
||||
messages = {}
|
||||
else
|
||||
messages = {
|
||||
{ content = opts.system_prompt, role = "SYSTEM" },
|
||||
}
|
||||
end
|
||||
vim
|
||||
.iter(opts.messages)
|
||||
:each(function(msg) table.insert(messages, { content = msg.content, role = M.role_map[msg.role] }) end)
|
||||
return messages
|
||||
end
|
||||
|
||||
--- This function will be used to parse incoming SSE stream
|
||||
--- It takes in the data stream as the first argument, followed by SSE event state, and opts
|
||||
--- retrieved from given buffer.
|
||||
--- This opts include:
|
||||
--- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk
|
||||
--- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk
|
||||
local function parse_response_wo_stream(self, data, _, opts)
|
||||
if Utils.debug then Utils.debug("WCA parse_response_without_stream called with opts: " .. vim.inspect(opts)) end
|
||||
|
||||
local json = vim.json.decode(data)
|
||||
if Utils.debug then Utils.debug("WCA Response: " .. vim.inspect(json)) end
|
||||
if json.error ~= nil and json.error ~= vim.NIL then
|
||||
Utils.warn("WCA Error " .. tostring(json.error.code) .. ": " .. tostring(json.error.message))
|
||||
end
|
||||
if json.response and json.response.message and json.response.message.content then
|
||||
local content = json.response.message.content
|
||||
|
||||
if Utils.debug then Utils.debug("WCA Original Content: " .. tostring(content)) end
|
||||
|
||||
-- Clean up the content by removing XML-like tags that are not part of the actual response
|
||||
-- These tags appear to be internal formatting from watsonx that should not be shown to users
|
||||
-- Use more careful patterns to avoid removing too much content
|
||||
content = content:gsub("<file>\n?", "")
|
||||
content = content:gsub("\n?</file>", "")
|
||||
content = content:gsub("\n?<memory>.-</memory>\n?", "")
|
||||
content = content:gsub("\n?<update_todo_status>.-</update_todo_status>\n?", "")
|
||||
content = content:gsub("\n?<attempt_completion>.-</attempt_completion>\n?", "")
|
||||
|
||||
-- Trim excessive whitespace but preserve structure
|
||||
content = content:gsub("^\n+", ""):gsub("\n+$", "")
|
||||
|
||||
if Utils.debug then Utils.debug("WCA Cleaned Content: " .. tostring(content)) end
|
||||
|
||||
-- Ensure we still have content after cleaning
|
||||
if content and content ~= "" then
|
||||
if opts.on_chunk then opts.on_chunk(content) end
|
||||
-- Add the text message for UI display (similar to OpenAI provider)
|
||||
OpenAI:add_text_message({}, content, "generated", opts)
|
||||
else
|
||||
Utils.warn("WCA: Content became empty after cleaning")
|
||||
if opts.on_chunk then
|
||||
opts.on_chunk(json.response.message.content) -- Fallback to original content
|
||||
end
|
||||
-- Add the original content as fallback
|
||||
OpenAI:add_text_message({}, json.response.message.content, "generated", opts)
|
||||
end
|
||||
vim.schedule(function()
|
||||
if opts.on_stop then opts.on_stop({ reason = "complete" }) end
|
||||
end)
|
||||
elseif json.error and json.error ~= vim.NIL then
|
||||
vim.schedule(function()
|
||||
if opts.on_stop then
|
||||
opts.on_stop({
|
||||
reason = "error",
|
||||
error = "WCA Error " .. tostring(json.error.code) .. ": " .. tostring(json.error.message),
|
||||
})
|
||||
end
|
||||
end)
|
||||
else
|
||||
-- Handle case where there's no response content and no explicit error
|
||||
if Utils.debug then Utils.debug("WCA: No content found in response, treating as empty response") end
|
||||
vim.schedule(function()
|
||||
if opts.on_stop then opts.on_stop({ reason = "complete" }) end
|
||||
end)
|
||||
end
|
||||
end
|
||||
|
||||
M.parse_response_without_stream = parse_response_wo_stream
|
||||
|
||||
-- Needs to be language specific for each function and methods.
|
||||
local get_function_name_under_cursor = function()
|
||||
local current_node = ts_utils.get_node_at_cursor()
|
||||
if not current_node then return "" end
|
||||
local expr = current_node
|
||||
|
||||
while expr do
|
||||
if expr:type() == "function_definition" or expr:type() == "method_declaration" then break end
|
||||
expr = expr:parent()
|
||||
end
|
||||
|
||||
if not expr then return "" end
|
||||
|
||||
local result = (ts_utils.get_node_text(expr:child(1)))[1]
|
||||
return result
|
||||
end
|
||||
|
||||
--- It takes in the provider options as the first argument, followed by code_opts retrieved from given buffer.
|
||||
---@type fun(command_name: string): nil
|
||||
M.method_command = function(command_name)
|
||||
if
|
||||
command_name ~= "document"
|
||||
and command_name ~= "unit-test"
|
||||
and command_name ~= "explain"
|
||||
and command_name:find("translate", 1, true) == 0
|
||||
then
|
||||
Utils.warn("Invalid command name" .. command_name)
|
||||
end
|
||||
|
||||
local current_buffer = vim.api.nvim_get_current_buf()
|
||||
local file_path = vim.api.nvim_buf_get_name(current_buffer)
|
||||
|
||||
-- Use file name for now. For proper extraction of method names, a lang specific TreeSitter querry is need
|
||||
-- local method_name = get_function_name_under_cursor()
|
||||
-- use whole file if we cannot get the method
|
||||
local method_name = ""
|
||||
if method_name == "" then
|
||||
local path_splits = vim.split(file_path, "/")
|
||||
method_name = path_splits[#path_splits]
|
||||
end
|
||||
|
||||
local sidebar = require("avante").get()
|
||||
if not sidebar then
|
||||
require("avante.api").ask()
|
||||
sidebar = require("avante").get()
|
||||
end
|
||||
if not sidebar:is_open() then sidebar:open({}) end
|
||||
sidebar.file_selector:add_current_buffer()
|
||||
|
||||
local response_content = ""
|
||||
local provider = P[Config.provider]
|
||||
local content = "/" .. command_name .. " @" .. method_name
|
||||
Llm.curl({
|
||||
provider = provider,
|
||||
prompt_opts = {
|
||||
system_prompt = "WCA_COMMAND",
|
||||
messages = {
|
||||
{ content = content, role = "user" },
|
||||
},
|
||||
selected_files = sidebar.file_selector:get_selected_files_contents(),
|
||||
},
|
||||
handler_opts = {
|
||||
on_start = function(_) end,
|
||||
on_chunk = function(chunk)
|
||||
if not chunk then return end
|
||||
response_content = response_content .. chunk
|
||||
end,
|
||||
on_stop = function(stop_opts)
|
||||
if stop_opts.error ~= nil then
|
||||
Utils.error(string.format("WCA Command " .. command_name .. " failed: %s", vim.inspect(stop_opts.error)))
|
||||
return
|
||||
end
|
||||
if stop_opts.reason == "complete" then
|
||||
if not sidebar:is_open() then sidebar:open({}) end
|
||||
sidebar:update_content(response_content, { focus = true })
|
||||
end
|
||||
end,
|
||||
},
|
||||
})
|
||||
end
|
||||
|
||||
local function get_iam_bearer_token(provider)
|
||||
if M.last_iam_token_time ~= nil and os.time() - M.last_iam_token_time <= 3550 then return M.iam_bearer_token end
|
||||
|
||||
local api_key = provider.parse_api_key()
|
||||
if api_key == nil then
|
||||
-- if no api key is available, make a request with a empty api key.
|
||||
api_key = ""
|
||||
end
|
||||
|
||||
local url = "https://iam.cloud.ibm.com/identity/token"
|
||||
local header = { ["Content-Type"] = "application/x-www-form-urlencoded" }
|
||||
local body = "grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey=" .. api_key
|
||||
|
||||
local response = curl.post(url, { headers = header, body = body })
|
||||
if response.status == 200 then
|
||||
-- select first key value pair
|
||||
local access_token_field = vim.split(response.body, ",")[1]
|
||||
-- get value
|
||||
local token = vim.split(access_token_field, ":")[2]
|
||||
-- remove quotes
|
||||
M.iam_bearer_token = (token:gsub("^%p(.*)%p$", "%1"))
|
||||
M.last_iam_token_time = os.time()
|
||||
else
|
||||
Utils.error(
|
||||
"Failed to retrieve IAM token: " .. response.status .. ": " .. vim.inspect(response.body),
|
||||
{ title = "Avante WCA" }
|
||||
)
|
||||
M.iam_bearer_token = ""
|
||||
end
|
||||
return M.iam_bearer_token
|
||||
end
|
||||
|
||||
local random = math.random
|
||||
math.randomseed(os.time())
|
||||
local function uuid()
|
||||
local template = "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx"
|
||||
return string.gsub(template, "[xy]", function(c)
|
||||
local v = (c == "x") and random(0, 0xf) or random(8, 0xb)
|
||||
return string.format("%x", v)
|
||||
end)
|
||||
end
|
||||
|
||||
--- This function below will be used to parse in cURL arguments.
|
||||
--- It takes in the provider options as the first argument, followed by code_opts retrieved from given buffer.
|
||||
--- This code_opts include:
|
||||
--- - question: Input from the users
|
||||
--- - code_lang: the language of given code buffer
|
||||
--- - code_content: content of code buffer
|
||||
--- - selected_code_content: (optional) If given code content is selected in visual mode as context.
|
||||
---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||
---@param provider AvanteProviderFunctor
|
||||
---@param code_opts AvantePromptOptions
|
||||
---@return table
|
||||
M.parse_curl_args = function(provider, code_opts)
|
||||
local base, _ = P.parse_config(provider)
|
||||
local headers = {
|
||||
["Content-Type"] = "multipart/form-data",
|
||||
["Authorization"] = "Bearer " .. get_iam_bearer_token(provider),
|
||||
["Request-ID"] = uuid(),
|
||||
}
|
||||
|
||||
-- Create the message_payload structure as required by WCA API
|
||||
local message_payload = {
|
||||
message_payload = {
|
||||
chat_session_id = uuid(), -- Required for granite-3-8b-instruct model
|
||||
messages = M:parse_messages(code_opts),
|
||||
},
|
||||
}
|
||||
|
||||
-- Base64 encode the message payload as required by watsonx API
|
||||
local json_content = vim.json.encode(message_payload)
|
||||
local encoded_json_content = vim.base64.encode(json_content)
|
||||
|
||||
-- Return form data structure - the message field contains the base64-encoded JSON
|
||||
local body = {
|
||||
message = encoded_json_content,
|
||||
}
|
||||
|
||||
return {
|
||||
url = base.endpoint,
|
||||
timeout = base.timeout,
|
||||
insecure = false,
|
||||
headers = headers,
|
||||
body = body,
|
||||
}
|
||||
end
|
||||
|
||||
--- The following function SHOULD only be used when providers doesn't follow SSE spec [ADVANCED]
|
||||
--- this is mutually exclusive with parse_response_data
|
||||
|
||||
return M
|
||||
@@ -2543,7 +2543,7 @@ function Sidebar:create_input_container()
|
||||
vim.keymap.del("n", "G", { buffer = self.containers.result.bufnr })
|
||||
end)
|
||||
|
||||
if stop_opts.error ~= nil then
|
||||
if stop_opts.error ~= nil and stop_opts.error ~= vim.NIL then
|
||||
local msg_content = stop_opts.error
|
||||
if type(msg_content) ~= "string" then msg_content = vim.inspect(msg_content) end
|
||||
self:add_history_messages({
|
||||
|
||||
69
tests/providers/watsonx_code_assistant_spec.lua
Normal file
69
tests/providers/watsonx_code_assistant_spec.lua
Normal file
@@ -0,0 +1,69 @@
|
||||
local busted = require("plenary.busted")
|
||||
|
||||
busted.describe("watsonx_code_assistant provider", function()
|
||||
local watsonx_provider
|
||||
|
||||
busted.before_each(function()
|
||||
-- Minimal setup without extensive mocking
|
||||
watsonx_provider = require("avante.providers.watsonx_code_assistant")
|
||||
end)
|
||||
|
||||
busted.describe("basic configuration", function()
|
||||
busted.it("should have required properties", function()
|
||||
assert.is_not_nil(watsonx_provider.api_key_name)
|
||||
assert.equals("WCA_API_KEY", watsonx_provider.api_key_name)
|
||||
assert.is_not_nil(watsonx_provider.role_map)
|
||||
assert.equals("USER", watsonx_provider.role_map.user)
|
||||
assert.equals("ASSISTANT", watsonx_provider.role_map.assistant)
|
||||
end)
|
||||
|
||||
busted.it("should disable streaming", function() assert.is_true(watsonx_provider:is_disable_stream()) end)
|
||||
|
||||
busted.it("should have required functions", function()
|
||||
assert.is_function(watsonx_provider.parse_messages)
|
||||
assert.is_function(watsonx_provider.parse_response_without_stream)
|
||||
assert.is_function(watsonx_provider.parse_curl_args)
|
||||
end)
|
||||
end)
|
||||
|
||||
busted.describe("parse_messages", function()
|
||||
busted.it("should parse messages with correct role mapping", function()
|
||||
---@type AvantePromptOptions
|
||||
local opts = {
|
||||
system_prompt = "You are a helpful assistant",
|
||||
messages = {
|
||||
{ content = "Hello", role = "user" },
|
||||
{ content = "Hi there", role = "assistant" },
|
||||
},
|
||||
}
|
||||
|
||||
local result = watsonx_provider:parse_messages(opts)
|
||||
|
||||
assert.is_table(result)
|
||||
assert.equals(3, #result) -- system + 2 messages
|
||||
assert.equals("SYSTEM", result[1].role)
|
||||
assert.equals("You are a helpful assistant", result[1].content)
|
||||
assert.equals("USER", result[2].role)
|
||||
assert.equals("Hello", result[2].content)
|
||||
assert.equals("ASSISTANT", result[3].role)
|
||||
assert.equals("Hi there", result[3].content)
|
||||
end)
|
||||
|
||||
busted.it("should handle WCA_COMMAND system prompt", function()
|
||||
---@type AvantePromptOptions
|
||||
local opts = {
|
||||
system_prompt = "WCA_COMMAND",
|
||||
messages = {
|
||||
{ content = "/document main.py", role = "user" },
|
||||
},
|
||||
}
|
||||
|
||||
local result = watsonx_provider:parse_messages(opts)
|
||||
|
||||
assert.is_table(result)
|
||||
assert.equals(1, #result) -- only user message, no system prompt
|
||||
assert.equals("USER", result[1].role)
|
||||
assert.equals("/document main.py", result[1].content)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
Reference in New Issue
Block a user