chore: prefer not to use function assignment (#1381)

This commit is contained in:
Peter Cardenas
2025-02-24 20:08:03 -08:00
committed by GitHub
parent a1d1697400
commit afa674c6fd
29 changed files with 571 additions and 147 deletions

View File

@@ -17,7 +17,7 @@ M.parse_messages = O.parse_messages
M.parse_response = O.parse_response
M.parse_response_without_stream = O.parse_response_without_stream
M.parse_curl_args = function(provider, prompt_opts)
function M.parse_curl_args(provider, prompt_opts)
local provider_conf, request_body = P.parse_config(provider)
local headers = {

View File

@@ -7,7 +7,7 @@ local M = {}
M.api_key_name = "BEDROCK_KEYS"
M.use_xml_format = true
M.load_model_handler = function()
function M.load_model_handler()
local provider_conf, _ = P.parse_config(P["bedrock"])
local bedrock_model = provider_conf.model
if provider_conf.model:match("anthropic") then bedrock_model = "claude" end
@@ -18,17 +18,17 @@ M.load_model_handler = function()
error(error_msg)
end
M.parse_response = function(ctx, data_stream, event_state, opts)
function M.parse_response(ctx, data_stream, event_state, opts)
local model_handler = M.load_model_handler()
return model_handler.parse_response(ctx, data_stream, event_state, opts)
end
M.build_bedrock_payload = function(prompt_opts, body_opts)
function M.build_bedrock_payload(prompt_opts, body_opts)
local model_handler = M.load_model_handler()
return model_handler.build_bedrock_payload(prompt_opts, body_opts)
end
M.parse_stream_data = function(data, opts)
function M.parse_stream_data(data, opts)
-- @NOTE: Decode and process Bedrock response
-- Each response contains a Base64-encoded `bytes` field, which is decoded into JSON.
-- The `type` field in the decoded JSON determines how the response is handled.
@@ -44,7 +44,7 @@ end
---@param provider AvanteBedrockProviderFunctor
---@param prompt_opts AvantePromptOptions
---@return table
M.parse_curl_args = function(provider, prompt_opts)
function M.parse_curl_args(provider, prompt_opts)
local base, body_opts = P.parse_config(provider)
local api_key = provider.parse_api_key()
@@ -86,7 +86,7 @@ M.parse_curl_args = function(provider, prompt_opts)
}
end
M.on_error = function(result)
function M.on_error(result)
if not result.body then
return Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" })
end

View File

@@ -16,7 +16,7 @@ M.role_map = {
assistant = "assistant",
}
M.parse_messages = function(opts)
function M.parse_messages(opts)
---@type AvanteBedrockClaudeMessage[]
local messages = {}
@@ -78,7 +78,7 @@ M.parse_response = Claude.parse_response
---@param prompt_opts AvantePromptOptions
---@param body_opts table
---@return table
M.build_bedrock_payload = function(prompt_opts, body_opts)
function M.build_bedrock_payload(prompt_opts, body_opts)
local system_prompt = prompt_opts.system_prompt or ""
local messages = M.parse_messages(prompt_opts)
local max_tokens = body_opts.max_tokens or 2000

View File

@@ -36,7 +36,7 @@ M.role_map = {
assistant = "assistant",
}
M.parse_messages = function(opts)
function M.parse_messages(opts)
---@type AvanteClaudeMessage[]
local messages = {}
@@ -123,7 +123,7 @@ M.parse_messages = function(opts)
return messages
end
M.parse_response = function(ctx, data_stream, event_state, opts)
function M.parse_response(ctx, data_stream, event_state, opts)
if event_state == nil then
if data_stream:match('"message_start"') then
event_state = "message_start"
@@ -214,7 +214,7 @@ end
---@param provider AvanteProviderFunctor
---@param prompt_opts AvantePromptOptions
---@return table
M.parse_curl_args = function(provider, prompt_opts)
function M.parse_curl_args(provider, prompt_opts)
local provider_conf, request_body = P.parse_config(provider)
local disable_tools = provider_conf.disable_tools or false
@@ -256,7 +256,7 @@ M.parse_curl_args = function(provider, prompt_opts)
}
end
M.on_error = function(result)
function M.on_error(result)
if not result.body then
return Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" })
end

View File

@@ -47,7 +47,7 @@ M.role_map = {
assistant = "assistant",
}
M.parse_messages = function(opts)
function M.parse_messages(opts)
local messages = {
{ role = "system", content = opts.system_prompt },
}
@@ -57,7 +57,7 @@ M.parse_messages = function(opts)
return { messages = messages }
end
M.parse_stream_data = function(data, opts)
function M.parse_stream_data(data, opts)
---@type CohereChatResponse
local json = vim.json.decode(data)
if json.type ~= nil then
@@ -69,7 +69,7 @@ M.parse_stream_data = function(data, opts)
end
end
M.parse_curl_args = function(provider, prompt_opts)
function M.parse_curl_args(provider, prompt_opts)
local provider_conf, request_body = P.parse_config(provider)
local headers = {
@@ -96,7 +96,7 @@ M.parse_curl_args = function(provider, prompt_opts)
}
end
M.setup = function()
function M.setup()
P.env.parse_envvar(M)
require("avante.tokenizers").setup(M.tokenizer_id, false)
vim.g.avante_login = true

View File

@@ -98,7 +98,7 @@ end
---@field oauth_token string
---
---@return string
H.get_oauth_token = function()
function H.get_oauth_token()
local xdg_config = vim.fn.expand("$XDG_CONFIG_HOME")
local os_name = Utils.get_os_name()
---@type string
@@ -138,9 +138,9 @@ H.get_oauth_token = function()
end
H.chat_auth_url = "https://api.github.com/copilot_internal/v2/token"
H.chat_completion_url = function(base_url) return Utils.url_join(base_url, "/chat/completions") end
function H.chat_completion_url(base_url) return Utils.url_join(base_url, "/chat/completions") end
H.refresh_token = function(async, force)
function H.refresh_token(async, force)
if not M.state then error("internal initialization error") end
async = async == nil and true or async
@@ -166,7 +166,7 @@ H.refresh_token = function(async, force)
insecure = Config.copilot.allow_insecure,
}
local handle_response = function(response)
local function handle_response(response)
if response.status == 200 then
M.state.github_token = vim.json.decode(response.body)
local file = Path:new(copilot_path)
@@ -209,7 +209,7 @@ M.role_map = {
assistant = "assistant",
}
M.parse_messages = function(opts)
function M.parse_messages(opts)
local messages = {
{ role = "system", content = opts.system_prompt },
}
@@ -245,7 +245,7 @@ end
M.parse_response = OpenAI.parse_response
M.parse_curl_args = function(provider, prompt_opts)
function M.parse_curl_args(provider, prompt_opts)
-- refresh token synchronously, only if it has expired
-- (this should rarely happen, as we refresh the token in the background)
H.refresh_token(false, false)
@@ -282,7 +282,7 @@ end
M._refresh_timer = nil
M.setup_timer = function()
function M.setup_timer()
if M._refresh_timer then
M._refresh_timer:stop()
M._refresh_timer:close()
@@ -305,7 +305,7 @@ M.setup_timer = function()
)
end
M.setup_file_watcher = function()
function M.setup_file_watcher()
if M._file_watcher then return end
local copilot_token_file = Path:new(copilot_path)
@@ -321,7 +321,7 @@ M.setup_file_watcher = function()
)
end
M.setup = function()
function M.setup()
local copilot_token_file = Path:new(copilot_path)
if not M.state then M.state = {
@@ -351,7 +351,7 @@ M.setup = function()
vim.g.avante_login = true
end
M.cleanup = function()
function M.cleanup()
-- Cleanup refresh timer
if M._refresh_timer then
M._refresh_timer:stop()

View File

@@ -12,7 +12,7 @@ M.role_map = {
}
-- M.tokenizer_id = "google/gemma-2b"
M.parse_messages = function(opts)
function M.parse_messages(opts)
local contents = {}
local prev_role = nil
@@ -64,7 +64,7 @@ M.parse_messages = function(opts)
}
end
M.parse_response = function(ctx, data_stream, _, opts)
function M.parse_response(ctx, data_stream, _, opts)
local ok, json = pcall(vim.json.decode, data_stream)
if not ok then opts.on_stop({ reason = "error", error = json }) end
if json.candidates then
@@ -81,7 +81,7 @@ M.parse_response = function(ctx, data_stream, _, opts)
end
end
M.parse_curl_args = function(provider, prompt_opts)
function M.parse_curl_args(provider, prompt_opts)
local provider_conf, request_body = P.parse_config(provider)
request_body = vim.tbl_deep_extend("force", request_body, {

View File

@@ -29,7 +29,7 @@ E.cache = {}
---@param Opts AvanteSupportedProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor
---@return string | nil
E.parse_envvar = function(Opts)
function E.parse_envvar(Opts)
local api_key_name = Opts.api_key_name
if api_key_name == nil then error("Requires api_key_name") end
@@ -91,7 +91,7 @@ end
--- This will only run once and spawn a UI for users to input the envvar.
---@param opts {refresh: boolean, provider: AvanteProviderFunctor | AvanteBedrockProviderFunctor}
---@private
E.setup = function(opts)
function E.setup(opts)
opts.provider.setup()
local var = opts.provider.api_key_name
@@ -180,7 +180,7 @@ end
E.REQUEST_LOGIN_PATTERN = "AvanteRequestLogin"
---@param provider AvanteDefaultBaseProvider
E.require_api_key = function(provider)
function E.require_api_key(provider)
if provider["local"] ~= nil then
if provider["local"] then
vim.deprecate('"local" = true', "api_key_name = ''", "0.1.0", "avante.nvim")
@@ -239,7 +239,7 @@ M = setmetatable(M, {
end,
})
M.setup = function()
function M.setup()
vim.g.avante_login = false
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
@@ -274,7 +274,7 @@ end
---@param opts AvanteProvider | AvanteSupportedProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor
---@return AvanteDefaultBaseProvider provider_opts
---@return table<string, any> request_body
M.parse_config = function(opts)
function M.parse_config(opts)
---@type AvanteDefaultBaseProvider
local provider_opts = {}
---@type table<string, any>
@@ -302,7 +302,7 @@ end
---@private
---@param provider Provider
---@return AvanteProviderFunctor | AvanteBedrockProviderFunctor
M.get_config = function(provider)
function M.get_config(provider)
provider = provider or Config.provider
local cur = Config.get_provider(provider)
return type(cur) == "function" and cur() or cur

View File

@@ -43,10 +43,10 @@ function M.transform_tool(tool)
return res
end
M.is_openrouter = function(url) return url:match("^https://openrouter%.ai/") end
function M.is_openrouter(url) return url:match("^https://openrouter%.ai/") end
---@param opts AvantePromptOptions
M.get_user_message = function(opts)
function M.get_user_message(opts)
vim.deprecate("get_user_message", "parse_messages", "0.1.0", "avante.nvim")
return table.concat(
vim
@@ -61,9 +61,9 @@ M.get_user_message = function(opts)
)
end
M.is_o_series_model = function(model) return model and string.match(model, "^o%d+") ~= nil end
function M.is_o_series_model(model) return model and string.match(model, "^o%d+") ~= nil end
M.parse_messages = function(opts)
function M.parse_messages(opts)
local messages = {}
local provider = P[Config.provider]
local base, _ = P.parse_config(provider)
@@ -137,7 +137,7 @@ M.parse_messages = function(opts)
return final_messages
end
M.parse_response = function(ctx, data_stream, _, opts)
function M.parse_response(ctx, data_stream, _, opts)
if data_stream:match('"%[DONE%]":') then
opts.on_stop({ reason = "complete" })
return
@@ -205,7 +205,7 @@ M.parse_response = function(ctx, data_stream, _, opts)
end
end
M.parse_response_without_stream = function(data, _, opts)
function M.parse_response_without_stream(data, _, opts)
---@type AvanteOpenAIChatResponse
local json = vim.json.decode(data)
if json.choices and json.choices[1] then
@@ -217,7 +217,7 @@ M.parse_response_without_stream = function(data, _, opts)
end
end
M.parse_curl_args = function(provider, prompt_opts)
function M.parse_curl_args(provider, prompt_opts)
local provider_conf, request_body = P.parse_config(provider)
local disable_tools = provider_conf.disable_tools or false

View File

@@ -22,7 +22,7 @@ local function execute_command(command)
return result:match("^%s*(.-)%s*$")
end
M.parse_api_key = function()
function M.parse_api_key()
if not M.api_key_name:match("^cmd:") then
error("Invalid api_key_name: Expected 'cmd:<command>' format, got '" .. M.api_key_name .. "'")
end
@@ -31,7 +31,7 @@ M.parse_api_key = function()
return direct_output
end
M.parse_curl_args = function(provider, prompt_opts)
function M.parse_curl_args(provider, prompt_opts)
local provider_conf, request_body = P.parse_config(provider)
local location = vim.fn.getenv("LOCATION") or "default-location"
local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id"