From aff9dea03ce076170fe3b2a12c678179c78819a5 Mon Sep 17 00:00:00 2001 From: Peter Cardenas <16930781+PeterCardenas@users.noreply.github.com> Date: Mon, 12 May 2025 01:54:14 -0700 Subject: [PATCH] fix: rag search tool is asynchronous (#1927) --- lua/avante/llm_tools/init.lua | 16 ++++++++++++---- lua/avante/rag_service.lua | 36 ++++++++++++++++++----------------- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/lua/avante/llm_tools/init.lua b/lua/avante/llm_tools/init.lua index 6adca75..a748c11 100644 --- a/lua/avante/llm_tools/init.lua +++ b/lua/avante/llm_tools/init.lua @@ -502,16 +502,24 @@ function M.git_commit(opts, on_log, on_complete) end ---@type AvanteLLMToolFunc<{ query: string }> -function M.rag_search(opts, on_log) +function M.rag_search(opts, on_log, on_complete) if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end if not opts.query then return nil, "No query provided" end if on_log then on_log("query: " .. opts.query) end local root = Utils.get_project_root() local uri = "file://" .. root if uri:sub(-1) ~= "/" then uri = uri .. "/" end - local resp, err = RagService.retrieve(uri, opts.query) - if err then return nil, err end - return vim.json.encode(resp), nil + RagService.retrieve( + uri, + opts.query, + vim.schedule_wrap(function(resp, err) + if err then + on_complete(nil, err) + return + end + on_complete(vim.json.encode(resp), nil) + end) + ) end ---@type AvanteLLMToolFunc<{ code: string, rel_path: string, container_image?: string }> diff --git a/lua/avante/rag_service.lua b/lua/avante/rag_service.lua index 7f52eac..10b089b 100644 --- a/lua/avante/rag_service.lua +++ b/lua/avante/rag_service.lua @@ -299,11 +299,10 @@ end ---@param base_uri string ---@param query string ----@return AvanteRagServiceRetrieveResponse | nil resp ----@return string | nil error -function M.retrieve(base_uri, query) +---@param on_complete fun(resp: AvanteRagServiceRetrieveResponse | nil, error: string | nil): nil +function M.retrieve(base_uri, query, on_complete) base_uri = M.to_container_uri(base_uri) - local resp = curl.post(M.get_rag_service_url() .. "/api/v1/retrieve", { + curl.post(M.get_rag_service_url() .. "/api/v1/retrieve", { headers = { ["Content-Type"] = "application/json", }, @@ -313,20 +312,23 @@ function M.retrieve(base_uri, query) top_k = 10, }), timeout = 100000, + callback = function(resp) + if resp.status ~= 200 then + Utils.error("failed to retrieve: " .. resp.body) + on_complete(nil, resp.body) + return + end + local jsn = vim.json.decode(resp.body) + jsn.sources = vim + .iter(jsn.sources) + :map(function(source) + local uri = M.to_local_uri(source.uri) + return vim.tbl_deep_extend("force", source, { uri = uri }) + end) + :totable() + on_complete(jsn, nil) + end, }) - if resp.status ~= 200 then - Utils.error("failed to retrieve: " .. resp.body) - return nil, "failed to retrieve: " .. resp.body - end - local jsn = vim.json.decode(resp.body) - jsn.sources = vim - .iter(jsn.sources) - :map(function(source) - local uri = M.to_local_uri(source.uri) - return vim.tbl_deep_extend("force", source, { uri = uri }) - end) - :totable() - return jsn, nil end ---@class AvanteRagServiceIndexingStatusSummary