fix: anthropic rate limit (#1583)
This commit is contained in:
@@ -225,6 +225,20 @@ function M.calculate_tokens(opts)
|
|||||||
return tokens
|
return tokens
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local parse_headers = function(headers_file)
|
||||||
|
local headers = {}
|
||||||
|
local file = io.open(headers_file, "r")
|
||||||
|
if file then
|
||||||
|
for line in file:lines() do
|
||||||
|
line = line:gsub("\r$", "")
|
||||||
|
local key, value = line:match("^%s*(.-)%s*:%s*(.*)$")
|
||||||
|
if key and value then headers[key] = value end
|
||||||
|
end
|
||||||
|
file:close()
|
||||||
|
end
|
||||||
|
return headers
|
||||||
|
end
|
||||||
|
|
||||||
---@param opts avante.CurlOpts
|
---@param opts avante.CurlOpts
|
||||||
function M.curl(opts)
|
function M.curl(opts)
|
||||||
local provider = opts.provider
|
local provider = opts.provider
|
||||||
@@ -257,24 +271,39 @@ function M.curl(opts)
|
|||||||
|
|
||||||
local active_job
|
local active_job
|
||||||
|
|
||||||
local curl_body_file = fn.tempname() .. ".json"
|
local temp_file = fn.tempname()
|
||||||
|
local curl_body_file = temp_file .. "-request-body.json"
|
||||||
local json_content = vim.json.encode(spec.body)
|
local json_content = vim.json.encode(spec.body)
|
||||||
fn.writefile(vim.split(json_content, "\n"), curl_body_file)
|
fn.writefile(vim.split(json_content, "\n"), curl_body_file)
|
||||||
|
|
||||||
Utils.debug("curl body file:", curl_body_file)
|
Utils.debug("curl body file:", curl_body_file)
|
||||||
|
|
||||||
|
local headers_file = temp_file .. "-headers.txt"
|
||||||
|
|
||||||
|
Utils.debug("curl headers file:", headers_file)
|
||||||
|
|
||||||
local function cleanup()
|
local function cleanup()
|
||||||
if Config.debug then return end
|
if Config.debug then return end
|
||||||
vim.schedule(function() fn.delete(curl_body_file) end)
|
vim.schedule(function()
|
||||||
|
fn.delete(curl_body_file)
|
||||||
|
fn.delete(headers_file)
|
||||||
|
end)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local headers_reported = false
|
||||||
|
|
||||||
active_job = curl.post(spec.url, {
|
active_job = curl.post(spec.url, {
|
||||||
headers = spec.headers,
|
headers = spec.headers,
|
||||||
proxy = spec.proxy,
|
proxy = spec.proxy,
|
||||||
insecure = spec.insecure,
|
insecure = spec.insecure,
|
||||||
body = curl_body_file,
|
body = curl_body_file,
|
||||||
raw = spec.rawArgs,
|
raw = spec.rawArgs,
|
||||||
|
dump = { "-D", headers_file },
|
||||||
stream = function(err, data, _)
|
stream = function(err, data, _)
|
||||||
|
if not headers_reported and opts.on_response_headers then
|
||||||
|
headers_reported = true
|
||||||
|
opts.on_response_headers(parse_headers(headers_file))
|
||||||
|
end
|
||||||
if err then
|
if err then
|
||||||
completed = true
|
completed = true
|
||||||
handler_opts.on_stop({ reason = "error", error = err })
|
handler_opts.on_stop({ reason = "error", error = err })
|
||||||
@@ -393,6 +422,8 @@ function M._stream(opts)
|
|||||||
|
|
||||||
local prompt_opts = M.generate_prompts(opts)
|
local prompt_opts = M.generate_prompts(opts)
|
||||||
|
|
||||||
|
local resp_headers = {}
|
||||||
|
|
||||||
---@type AvanteHandlerOptions
|
---@type AvanteHandlerOptions
|
||||||
local handler_opts = {
|
local handler_opts = {
|
||||||
on_start = opts.on_start,
|
on_start = opts.on_start,
|
||||||
@@ -406,7 +437,16 @@ function M._stream(opts)
|
|||||||
local new_opts = vim.tbl_deep_extend("force", opts, {
|
local new_opts = vim.tbl_deep_extend("force", opts, {
|
||||||
tool_histories = tool_histories,
|
tool_histories = tool_histories,
|
||||||
})
|
})
|
||||||
return M._stream(new_opts)
|
if provider.get_rate_limit_sleep_time then
|
||||||
|
local sleep_time = provider:get_rate_limit_sleep_time(resp_headers)
|
||||||
|
if sleep_time and sleep_time > 0 then
|
||||||
|
Utils.info("Rate limit reached. Sleeping for " .. sleep_time .. " seconds ...")
|
||||||
|
vim.defer_fn(function() M._stream(new_opts) end, sleep_time * 1000)
|
||||||
|
return
|
||||||
|
end
|
||||||
|
end
|
||||||
|
M._stream(new_opts)
|
||||||
|
return
|
||||||
end
|
end
|
||||||
local tool_use = tool_use_list[tool_use_index]
|
local tool_use = tool_use_list[tool_use_index]
|
||||||
---@param result string | nil
|
---@param result string | nil
|
||||||
@@ -467,6 +507,7 @@ function M._stream(opts)
|
|||||||
provider = provider,
|
provider = provider,
|
||||||
prompt_opts = prompt_opts,
|
prompt_opts = prompt_opts,
|
||||||
handler_opts = handler_opts,
|
handler_opts = handler_opts,
|
||||||
|
on_response_headers = function(headers) resp_headers = headers end,
|
||||||
})
|
})
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,23 @@ M.role_map = {
|
|||||||
assistant = "assistant",
|
assistant = "assistant",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
---@param headers table<string, string>
|
||||||
|
---@return integer|nil
|
||||||
|
function M:get_rate_limit_sleep_time(headers)
|
||||||
|
local remaining_tokens = tonumber(headers["anthropic-ratelimit-tokens-remaining"])
|
||||||
|
if remaining_tokens == nil then return end
|
||||||
|
if remaining_tokens > 10000 then return end
|
||||||
|
local reset_dt_str = headers["anthropic-ratelimit-tokens-reset"]
|
||||||
|
if remaining_tokens ~= 0 then reset_dt_str = reset_dt_str or headers["anthropic-ratelimit-requests-reset"] end
|
||||||
|
local reset_dt, err = Utils.parse_iso8601_date(reset_dt_str)
|
||||||
|
if err then
|
||||||
|
Utils.warn(err)
|
||||||
|
return
|
||||||
|
end
|
||||||
|
local now = Utils.utc_now()
|
||||||
|
return Utils.datetime_diff(tostring(now), tostring(reset_dt))
|
||||||
|
end
|
||||||
|
|
||||||
---@param tool AvanteLLMTool
|
---@param tool AvanteLLMTool
|
||||||
---@return AvanteClaudeTool
|
---@return AvanteClaudeTool
|
||||||
function M.transform_tool(tool)
|
function M.transform_tool(tool)
|
||||||
|
|||||||
@@ -244,6 +244,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field usage? AvanteLLMUsage
|
---@field usage? AvanteLLMUsage
|
||||||
---@field tool_use_list? AvanteLLMToolUse[]
|
---@field tool_use_list? AvanteLLMToolUse[]
|
||||||
---@field retry_after? integer
|
---@field retry_after? integer
|
||||||
|
---@field headers? table<string, string>
|
||||||
---
|
---
|
||||||
---@alias AvanteStreamParser fun(self: AvanteProviderFunctor, ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil
|
---@alias AvanteStreamParser fun(self: AvanteProviderFunctor, ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil
|
||||||
---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil
|
---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil
|
||||||
@@ -272,6 +273,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field parse_stream_data? AvanteStreamParser
|
---@field parse_stream_data? AvanteStreamParser
|
||||||
---@field on_error? fun(result: table<string, any>): nil
|
---@field on_error? fun(result: table<string, any>): nil
|
||||||
---@field transform_tool? fun(tool: AvanteLLMTool): AvanteOpenAITool | AvanteClaudeTool
|
---@field transform_tool? fun(tool: AvanteLLMTool): AvanteOpenAITool | AvanteClaudeTool
|
||||||
|
---@field get_rate_limit_sleep_time? fun(self: AvanteProviderFunctor, headers: table<string, string>): integer | nil
|
||||||
---
|
---
|
||||||
---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, request_body: table<string, any>): table<string, any>
|
---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, request_body: table<string, any>): table<string, any>
|
||||||
---
|
---
|
||||||
@@ -387,4 +389,5 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field provider AvanteProviderFunctor
|
---@field provider AvanteProviderFunctor
|
||||||
---@field prompt_opts AvantePromptOptions
|
---@field prompt_opts AvantePromptOptions
|
||||||
---@field handler_opts AvanteHandlerOptions
|
---@field handler_opts AvanteHandlerOptions
|
||||||
|
---@field on_response_headers? fun(headers: table<string, string>): nil
|
||||||
---
|
---
|
||||||
|
|||||||
@@ -1119,4 +1119,48 @@ function M.deep_extend_with_metatable(behavior, ...)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function M.utc_now()
|
||||||
|
local utc_date = os.date("!*t")
|
||||||
|
---@diagnostic disable-next-line: param-type-mismatch
|
||||||
|
local utc_time = os.time(utc_date)
|
||||||
|
return os.date("%Y-%m-%d %H:%M:%S", utc_time)
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param dt1 string
|
||||||
|
---@param dt2 string
|
||||||
|
---@return integer delta_seconds
|
||||||
|
function M.datetime_diff(dt1, dt2)
|
||||||
|
local pattern = "(%d+)-(%d+)-(%d+) (%d+):(%d+):(%d+)"
|
||||||
|
local y1, m1, d1, h1, min1, s1 = dt1:match(pattern)
|
||||||
|
local y2, m2, d2, h2, min2, s2 = dt2:match(pattern)
|
||||||
|
|
||||||
|
local time1 = os.time({ year = y1, month = m1, day = d1, hour = h1, min = min1, sec = s1 })
|
||||||
|
local time2 = os.time({ year = y2, month = m2, day = d2, hour = h2, min = min2, sec = s2 })
|
||||||
|
|
||||||
|
local delta_seconds = os.difftime(time2, time1)
|
||||||
|
return delta_seconds
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param iso_str string
|
||||||
|
---@return string|nil
|
||||||
|
---@return string|nil error
|
||||||
|
function M.parse_iso8601_date(iso_str)
|
||||||
|
local year, month, day, hour, min, sec = iso_str:match("(%d+)-(%d+)-(%d+)T(%d+):(%d+):(%d+)Z")
|
||||||
|
if not year then return nil, "Invalid ISO 8601 format" end
|
||||||
|
|
||||||
|
local time_table = {
|
||||||
|
year = tonumber(year),
|
||||||
|
month = tonumber(month),
|
||||||
|
day = tonumber(day),
|
||||||
|
hour = tonumber(hour),
|
||||||
|
min = tonumber(min),
|
||||||
|
sec = tonumber(sec),
|
||||||
|
isdst = false,
|
||||||
|
}
|
||||||
|
|
||||||
|
local timestamp = os.time(time_table)
|
||||||
|
|
||||||
|
return tostring(os.date("%Y-%m-%d %H:%M:%S", timestamp)), nil
|
||||||
|
end
|
||||||
|
|
||||||
return M
|
return M
|
||||||
|
|||||||
Reference in New Issue
Block a user