From e332d74c064f2cdcebbe9914d88300cfd3a74ea8 Mon Sep 17 00:00:00 2001 From: yetone Date: Fri, 14 Mar 2025 14:13:47 +0800 Subject: [PATCH] fix: anthropic rate limit (#1583) --- lua/avante/llm.lua | 47 ++++++++++++++++++++++++++++++--- lua/avante/providers/claude.lua | 17 ++++++++++++ lua/avante/types.lua | 3 +++ lua/avante/utils/init.lua | 44 ++++++++++++++++++++++++++++++ 4 files changed, 108 insertions(+), 3 deletions(-) diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 1c1e19f..1a8cda7 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -225,6 +225,20 @@ function M.calculate_tokens(opts) return tokens end +local parse_headers = function(headers_file) + local headers = {} + local file = io.open(headers_file, "r") + if file then + for line in file:lines() do + line = line:gsub("\r$", "") + local key, value = line:match("^%s*(.-)%s*:%s*(.*)$") + if key and value then headers[key] = value end + end + file:close() + end + return headers +end + ---@param opts avante.CurlOpts function M.curl(opts) local provider = opts.provider @@ -257,24 +271,39 @@ function M.curl(opts) local active_job - local curl_body_file = fn.tempname() .. ".json" + local temp_file = fn.tempname() + local curl_body_file = temp_file .. "-request-body.json" local json_content = vim.json.encode(spec.body) fn.writefile(vim.split(json_content, "\n"), curl_body_file) Utils.debug("curl body file:", curl_body_file) + local headers_file = temp_file .. "-headers.txt" + + Utils.debug("curl headers file:", headers_file) + local function cleanup() if Config.debug then return end - vim.schedule(function() fn.delete(curl_body_file) end) + vim.schedule(function() + fn.delete(curl_body_file) + fn.delete(headers_file) + end) end + local headers_reported = false + active_job = curl.post(spec.url, { headers = spec.headers, proxy = spec.proxy, insecure = spec.insecure, body = curl_body_file, raw = spec.rawArgs, + dump = { "-D", headers_file }, stream = function(err, data, _) + if not headers_reported and opts.on_response_headers then + headers_reported = true + opts.on_response_headers(parse_headers(headers_file)) + end if err then completed = true handler_opts.on_stop({ reason = "error", error = err }) @@ -393,6 +422,8 @@ function M._stream(opts) local prompt_opts = M.generate_prompts(opts) + local resp_headers = {} + ---@type AvanteHandlerOptions local handler_opts = { on_start = opts.on_start, @@ -406,7 +437,16 @@ function M._stream(opts) local new_opts = vim.tbl_deep_extend("force", opts, { tool_histories = tool_histories, }) - return M._stream(new_opts) + if provider.get_rate_limit_sleep_time then + local sleep_time = provider:get_rate_limit_sleep_time(resp_headers) + if sleep_time and sleep_time > 0 then + Utils.info("Rate limit reached. Sleeping for " .. sleep_time .. " seconds ...") + vim.defer_fn(function() M._stream(new_opts) end, sleep_time * 1000) + return + end + end + M._stream(new_opts) + return end local tool_use = tool_use_list[tool_use_index] ---@param result string | nil @@ -467,6 +507,7 @@ function M._stream(opts) provider = provider, prompt_opts = prompt_opts, handler_opts = handler_opts, + on_response_headers = function(headers) resp_headers = headers end, }) end diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index d70a95b..b5ce6a3 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -13,6 +13,23 @@ M.role_map = { assistant = "assistant", } +---@param headers table +---@return integer|nil +function M:get_rate_limit_sleep_time(headers) + local remaining_tokens = tonumber(headers["anthropic-ratelimit-tokens-remaining"]) + if remaining_tokens == nil then return end + if remaining_tokens > 10000 then return end + local reset_dt_str = headers["anthropic-ratelimit-tokens-reset"] + if remaining_tokens ~= 0 then reset_dt_str = reset_dt_str or headers["anthropic-ratelimit-requests-reset"] end + local reset_dt, err = Utils.parse_iso8601_date(reset_dt_str) + if err then + Utils.warn(err) + return + end + local now = Utils.utc_now() + return Utils.datetime_diff(tostring(now), tostring(reset_dt)) +end + ---@param tool AvanteLLMTool ---@return AvanteClaudeTool function M.transform_tool(tool) diff --git a/lua/avante/types.lua b/lua/avante/types.lua index e4b1dda..74518a5 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -244,6 +244,7 @@ vim.g.avante_login = vim.g.avante_login ---@field usage? AvanteLLMUsage ---@field tool_use_list? AvanteLLMToolUse[] ---@field retry_after? integer +---@field headers? table --- ---@alias AvanteStreamParser fun(self: AvanteProviderFunctor, ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil ---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil @@ -272,6 +273,7 @@ vim.g.avante_login = vim.g.avante_login ---@field parse_stream_data? AvanteStreamParser ---@field on_error? fun(result: table): nil ---@field transform_tool? fun(tool: AvanteLLMTool): AvanteOpenAITool | AvanteClaudeTool +---@field get_rate_limit_sleep_time? fun(self: AvanteProviderFunctor, headers: table): integer | nil --- ---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, request_body: table): table --- @@ -387,4 +389,5 @@ vim.g.avante_login = vim.g.avante_login ---@field provider AvanteProviderFunctor ---@field prompt_opts AvantePromptOptions ---@field handler_opts AvanteHandlerOptions +---@field on_response_headers? fun(headers: table): nil --- diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 5fc088a..8dc962c 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -1119,4 +1119,48 @@ function M.deep_extend_with_metatable(behavior, ...) return result end +function M.utc_now() + local utc_date = os.date("!*t") + ---@diagnostic disable-next-line: param-type-mismatch + local utc_time = os.time(utc_date) + return os.date("%Y-%m-%d %H:%M:%S", utc_time) +end + +---@param dt1 string +---@param dt2 string +---@return integer delta_seconds +function M.datetime_diff(dt1, dt2) + local pattern = "(%d+)-(%d+)-(%d+) (%d+):(%d+):(%d+)" + local y1, m1, d1, h1, min1, s1 = dt1:match(pattern) + local y2, m2, d2, h2, min2, s2 = dt2:match(pattern) + + local time1 = os.time({ year = y1, month = m1, day = d1, hour = h1, min = min1, sec = s1 }) + local time2 = os.time({ year = y2, month = m2, day = d2, hour = h2, min = min2, sec = s2 }) + + local delta_seconds = os.difftime(time2, time1) + return delta_seconds +end + +---@param iso_str string +---@return string|nil +---@return string|nil error +function M.parse_iso8601_date(iso_str) + local year, month, day, hour, min, sec = iso_str:match("(%d+)-(%d+)-(%d+)T(%d+):(%d+):(%d+)Z") + if not year then return nil, "Invalid ISO 8601 format" end + + local time_table = { + year = tonumber(year), + month = tonumber(month), + day = tonumber(day), + hour = tonumber(hour), + min = tonumber(min), + sec = tonumber(sec), + isdst = false, + } + + local timestamp = os.time(time_table) + + return tostring(os.date("%Y-%m-%d %H:%M:%S", timestamp)), nil +end + return M