fix: rag search tool is asynchronous (#1927)

This commit is contained in:
Peter Cardenas
2025-05-12 01:54:14 -07:00
committed by GitHub
parent adae032f5f
commit aff9dea03c
2 changed files with 31 additions and 21 deletions

View File

@@ -502,16 +502,24 @@ function M.git_commit(opts, on_log, on_complete)
end end
---@type AvanteLLMToolFunc<{ query: string }> ---@type AvanteLLMToolFunc<{ query: string }>
function M.rag_search(opts, on_log) function M.rag_search(opts, on_log, on_complete)
if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end
if not opts.query then return nil, "No query provided" end if not opts.query then return nil, "No query provided" end
if on_log then on_log("query: " .. opts.query) end if on_log then on_log("query: " .. opts.query) end
local root = Utils.get_project_root() local root = Utils.get_project_root()
local uri = "file://" .. root local uri = "file://" .. root
if uri:sub(-1) ~= "/" then uri = uri .. "/" end if uri:sub(-1) ~= "/" then uri = uri .. "/" end
local resp, err = RagService.retrieve(uri, opts.query) RagService.retrieve(
if err then return nil, err end uri,
return vim.json.encode(resp), nil opts.query,
vim.schedule_wrap(function(resp, err)
if err then
on_complete(nil, err)
return
end
on_complete(vim.json.encode(resp), nil)
end)
)
end end
---@type AvanteLLMToolFunc<{ code: string, rel_path: string, container_image?: string }> ---@type AvanteLLMToolFunc<{ code: string, rel_path: string, container_image?: string }>

View File

@@ -299,11 +299,10 @@ end
---@param base_uri string ---@param base_uri string
---@param query string ---@param query string
---@return AvanteRagServiceRetrieveResponse | nil resp ---@param on_complete fun(resp: AvanteRagServiceRetrieveResponse | nil, error: string | nil): nil
---@return string | nil error function M.retrieve(base_uri, query, on_complete)
function M.retrieve(base_uri, query)
base_uri = M.to_container_uri(base_uri) base_uri = M.to_container_uri(base_uri)
local resp = curl.post(M.get_rag_service_url() .. "/api/v1/retrieve", { curl.post(M.get_rag_service_url() .. "/api/v1/retrieve", {
headers = { headers = {
["Content-Type"] = "application/json", ["Content-Type"] = "application/json",
}, },
@@ -313,20 +312,23 @@ function M.retrieve(base_uri, query)
top_k = 10, top_k = 10,
}), }),
timeout = 100000, timeout = 100000,
callback = function(resp)
if resp.status ~= 200 then
Utils.error("failed to retrieve: " .. resp.body)
on_complete(nil, resp.body)
return
end
local jsn = vim.json.decode(resp.body)
jsn.sources = vim
.iter(jsn.sources)
:map(function(source)
local uri = M.to_local_uri(source.uri)
return vim.tbl_deep_extend("force", source, { uri = uri })
end)
:totable()
on_complete(jsn, nil)
end,
}) })
if resp.status ~= 200 then
Utils.error("failed to retrieve: " .. resp.body)
return nil, "failed to retrieve: " .. resp.body
end
local jsn = vim.json.decode(resp.body)
jsn.sources = vim
.iter(jsn.sources)
:map(function(source)
local uri = M.to_local_uri(source.uri)
return vim.tbl_deep_extend("force", source, { uri = uri })
end)
:totable()
return jsn, nil
end end
---@class AvanteRagServiceIndexingStatusSummary ---@class AvanteRagServiceIndexingStatusSummary