* fix: upgrade vscode version * feat: support openai response api * refactor: refine todos tools * fix: trim suffix empty lines
286 lines
10 KiB
Lua
286 lines
10 KiB
Lua
-- Documentation for setting up IBM Watsonx Code Assistant
|
|
--- Generating an access token: https://www.ibm.com/products/watsonx-code-assistant or https://github.ibm.com/code-assistant/wca-api
|
|
local P = require("avante.providers")
|
|
local Utils = require("avante.utils")
|
|
local curl = require("plenary.curl")
|
|
local Config = require("avante.config")
|
|
local Llm = require("avante.llm")
|
|
local ts_utils = pcall(require, "nvim-treesitter.ts_utils") and require("nvim-treesitter.ts_utils")
|
|
or {
|
|
get_node_at_cursor = function() return nil end,
|
|
}
|
|
local OpenAI = require("avante.providers.openai")
|
|
|
|
---@class AvanteProviderFunctor
|
|
local M = {}
|
|
|
|
M.api_key_name = "WCA_API_KEY" -- The name of the environment variable that contains the API key
|
|
M.role_map = {
|
|
user = "USER",
|
|
assistant = "ASSISTANT",
|
|
system = "SYSTEM",
|
|
}
|
|
M.last_iam_token_time = nil
|
|
M.iam_bearer_token = ""
|
|
|
|
function M:is_disable_stream() return true end
|
|
|
|
---@type fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): table
|
|
function M:parse_messages(opts)
|
|
if opts == nil then return {} end
|
|
local messages
|
|
if opts.system_prompt == "WCA_COMMAND" then
|
|
messages = {}
|
|
else
|
|
messages = {
|
|
{ content = opts.system_prompt, role = "SYSTEM" },
|
|
}
|
|
end
|
|
vim
|
|
.iter(opts.messages)
|
|
:each(function(msg) table.insert(messages, { content = msg.content, role = M.role_map[msg.role] }) end)
|
|
return messages
|
|
end
|
|
|
|
--- This function will be used to parse incoming SSE stream
|
|
--- It takes in the data stream as the first argument, followed by SSE event state, and opts
|
|
--- retrieved from given buffer.
|
|
--- This opts include:
|
|
--- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk
|
|
--- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk
|
|
local function parse_response_wo_stream(self, data, _, opts)
|
|
if Utils.debug then Utils.debug("WCA parse_response_without_stream called with opts: " .. vim.inspect(opts)) end
|
|
|
|
local json = vim.json.decode(data)
|
|
if Utils.debug then Utils.debug("WCA Response: " .. vim.inspect(json)) end
|
|
if json.error ~= nil and json.error ~= vim.NIL then
|
|
Utils.warn("WCA Error " .. tostring(json.error.code) .. ": " .. tostring(json.error.message))
|
|
end
|
|
if json.response and json.response.message and json.response.message.content then
|
|
local content = json.response.message.content
|
|
|
|
if Utils.debug then Utils.debug("WCA Original Content: " .. tostring(content)) end
|
|
|
|
-- Clean up the content by removing XML-like tags that are not part of the actual response
|
|
-- These tags appear to be internal formatting from watsonx that should not be shown to users
|
|
-- Use more careful patterns to avoid removing too much content
|
|
content = content:gsub("<file>\n?", "")
|
|
content = content:gsub("\n?</file>", "")
|
|
content = content:gsub("\n?<memory>.-</memory>\n?", "")
|
|
content = content:gsub("\n?<write_todos>.-</write_todos>\n?", "")
|
|
content = content:gsub("\n?<attempt_completion>.-</attempt_completion>\n?", "")
|
|
|
|
-- Trim excessive whitespace but preserve structure
|
|
content = content:gsub("^\n+", ""):gsub("\n+$", "")
|
|
|
|
if Utils.debug then Utils.debug("WCA Cleaned Content: " .. tostring(content)) end
|
|
|
|
-- Ensure we still have content after cleaning
|
|
if content and content ~= "" then
|
|
if opts.on_chunk then opts.on_chunk(content) end
|
|
-- Add the text message for UI display (similar to OpenAI provider)
|
|
OpenAI:add_text_message({}, content, "generated", opts)
|
|
else
|
|
Utils.warn("WCA: Content became empty after cleaning")
|
|
if opts.on_chunk then
|
|
opts.on_chunk(json.response.message.content) -- Fallback to original content
|
|
end
|
|
-- Add the original content as fallback
|
|
OpenAI:add_text_message({}, json.response.message.content, "generated", opts)
|
|
end
|
|
vim.schedule(function()
|
|
if opts.on_stop then opts.on_stop({ reason = "complete" }) end
|
|
end)
|
|
elseif json.error and json.error ~= vim.NIL then
|
|
vim.schedule(function()
|
|
if opts.on_stop then
|
|
opts.on_stop({
|
|
reason = "error",
|
|
error = "WCA Error " .. tostring(json.error.code) .. ": " .. tostring(json.error.message),
|
|
})
|
|
end
|
|
end)
|
|
else
|
|
-- Handle case where there's no response content and no explicit error
|
|
if Utils.debug then Utils.debug("WCA: No content found in response, treating as empty response") end
|
|
vim.schedule(function()
|
|
if opts.on_stop then opts.on_stop({ reason = "complete" }) end
|
|
end)
|
|
end
|
|
end
|
|
|
|
M.parse_response_without_stream = parse_response_wo_stream
|
|
|
|
-- Needs to be language specific for each function and methods.
|
|
local get_function_name_under_cursor = function()
|
|
local current_node = ts_utils.get_node_at_cursor()
|
|
if not current_node then return "" end
|
|
local expr = current_node
|
|
|
|
while expr do
|
|
if expr:type() == "function_definition" or expr:type() == "method_declaration" then break end
|
|
expr = expr:parent()
|
|
end
|
|
|
|
if not expr then return "" end
|
|
|
|
local result = (ts_utils.get_node_text(expr:child(1)))[1]
|
|
return result
|
|
end
|
|
|
|
--- It takes in the provider options as the first argument, followed by code_opts retrieved from given buffer.
|
|
---@type fun(command_name: string): nil
|
|
M.method_command = function(command_name)
|
|
if
|
|
command_name ~= "document"
|
|
and command_name ~= "unit-test"
|
|
and command_name ~= "explain"
|
|
and command_name:find("translate", 1, true) == 0
|
|
then
|
|
Utils.warn("Invalid command name" .. command_name)
|
|
end
|
|
|
|
local current_buffer = vim.api.nvim_get_current_buf()
|
|
local file_path = vim.api.nvim_buf_get_name(current_buffer)
|
|
|
|
-- Use file name for now. For proper extraction of method names, a lang specific TreeSitter querry is need
|
|
-- local method_name = get_function_name_under_cursor()
|
|
-- use whole file if we cannot get the method
|
|
local method_name = ""
|
|
if method_name == "" then
|
|
local path_splits = vim.split(file_path, "/")
|
|
method_name = path_splits[#path_splits]
|
|
end
|
|
|
|
local sidebar = require("avante").get()
|
|
if not sidebar then
|
|
require("avante.api").ask()
|
|
sidebar = require("avante").get()
|
|
end
|
|
if not sidebar:is_open() then sidebar:open({}) end
|
|
sidebar.file_selector:add_current_buffer()
|
|
|
|
local response_content = ""
|
|
local provider = P[Config.provider]
|
|
local content = "/" .. command_name .. " @" .. method_name
|
|
Llm.curl({
|
|
provider = provider,
|
|
prompt_opts = {
|
|
system_prompt = "WCA_COMMAND",
|
|
messages = {
|
|
{ content = content, role = "user" },
|
|
},
|
|
selected_files = sidebar.file_selector:get_selected_files_contents(),
|
|
},
|
|
handler_opts = {
|
|
on_start = function(_) end,
|
|
on_chunk = function(chunk)
|
|
if not chunk then return end
|
|
response_content = response_content .. chunk
|
|
end,
|
|
on_stop = function(stop_opts)
|
|
if stop_opts.error ~= nil then
|
|
Utils.error(string.format("WCA Command " .. command_name .. " failed: %s", vim.inspect(stop_opts.error)))
|
|
return
|
|
end
|
|
if stop_opts.reason == "complete" then
|
|
if not sidebar:is_open() then sidebar:open({}) end
|
|
sidebar:update_content(response_content, { focus = true })
|
|
end
|
|
end,
|
|
},
|
|
})
|
|
end
|
|
|
|
local function get_iam_bearer_token(provider)
|
|
if M.last_iam_token_time ~= nil and os.time() - M.last_iam_token_time <= 3550 then return M.iam_bearer_token end
|
|
|
|
local api_key = provider.parse_api_key()
|
|
if api_key == nil then
|
|
-- if no api key is available, make a request with a empty api key.
|
|
api_key = ""
|
|
end
|
|
|
|
local url = "https://iam.cloud.ibm.com/identity/token"
|
|
local header = { ["Content-Type"] = "application/x-www-form-urlencoded" }
|
|
local body = "grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey=" .. api_key
|
|
|
|
local response = curl.post(url, { headers = header, body = body })
|
|
if response.status == 200 then
|
|
-- select first key value pair
|
|
local access_token_field = vim.split(response.body, ",")[1]
|
|
-- get value
|
|
local token = vim.split(access_token_field, ":")[2]
|
|
-- remove quotes
|
|
M.iam_bearer_token = (token:gsub("^%p(.*)%p$", "%1"))
|
|
M.last_iam_token_time = os.time()
|
|
else
|
|
Utils.error(
|
|
"Failed to retrieve IAM token: " .. response.status .. ": " .. vim.inspect(response.body),
|
|
{ title = "Avante WCA" }
|
|
)
|
|
M.iam_bearer_token = ""
|
|
end
|
|
return M.iam_bearer_token
|
|
end
|
|
|
|
local random = math.random
|
|
math.randomseed(os.time())
|
|
local function uuid()
|
|
local template = "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx"
|
|
return string.gsub(template, "[xy]", function(c)
|
|
local v = (c == "x") and random(0, 0xf) or random(8, 0xb)
|
|
return string.format("%x", v)
|
|
end)
|
|
end
|
|
|
|
--- This function below will be used to parse in cURL arguments.
|
|
--- It takes in the provider options as the first argument, followed by code_opts retrieved from given buffer.
|
|
--- This code_opts include:
|
|
--- - question: Input from the users
|
|
--- - code_lang: the language of given code buffer
|
|
--- - code_content: content of code buffer
|
|
--- - selected_code_content: (optional) If given code content is selected in visual mode as context.
|
|
---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
|
|
---@param provider AvanteProviderFunctor
|
|
---@param code_opts AvantePromptOptions
|
|
---@return table
|
|
M.parse_curl_args = function(provider, code_opts)
|
|
local base, _ = P.parse_config(provider)
|
|
local headers = {
|
|
["Content-Type"] = "multipart/form-data",
|
|
["Authorization"] = "Bearer " .. get_iam_bearer_token(provider),
|
|
["Request-ID"] = uuid(),
|
|
}
|
|
|
|
-- Create the message_payload structure as required by WCA API
|
|
local message_payload = {
|
|
message_payload = {
|
|
chat_session_id = uuid(), -- Required for granite-3-8b-instruct model
|
|
messages = M:parse_messages(code_opts),
|
|
},
|
|
}
|
|
|
|
-- Base64 encode the message payload as required by watsonx API
|
|
local json_content = vim.json.encode(message_payload)
|
|
local encoded_json_content = vim.base64.encode(json_content)
|
|
|
|
-- Return form data structure - the message field contains the base64-encoded JSON
|
|
local body = {
|
|
message = encoded_json_content,
|
|
}
|
|
|
|
return {
|
|
url = base.endpoint,
|
|
timeout = base.timeout,
|
|
insecure = false,
|
|
headers = headers,
|
|
body = body,
|
|
}
|
|
end
|
|
|
|
--- The following function SHOULD only be used when providers doesn't follow SSE spec [ADVANCED]
|
|
--- this is mutually exclusive with parse_response_data
|
|
|
|
return M
|