fix: rag search tool is asynchronous (#1927)

This commit is contained in:
Peter Cardenas
2025-05-12 01:54:14 -07:00
committed by GitHub
parent adae032f5f
commit aff9dea03c
2 changed files with 31 additions and 21 deletions

View File

@@ -502,16 +502,24 @@ function M.git_commit(opts, on_log, on_complete)
end
---@type AvanteLLMToolFunc<{ query: string }>
function M.rag_search(opts, on_log)
function M.rag_search(opts, on_log, on_complete)
if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end
if not opts.query then return nil, "No query provided" end
if on_log then on_log("query: " .. opts.query) end
local root = Utils.get_project_root()
local uri = "file://" .. root
if uri:sub(-1) ~= "/" then uri = uri .. "/" end
local resp, err = RagService.retrieve(uri, opts.query)
if err then return nil, err end
return vim.json.encode(resp), nil
RagService.retrieve(
uri,
opts.query,
vim.schedule_wrap(function(resp, err)
if err then
on_complete(nil, err)
return
end
on_complete(vim.json.encode(resp), nil)
end)
)
end
---@type AvanteLLMToolFunc<{ code: string, rel_path: string, container_image?: string }>

View File

@@ -299,11 +299,10 @@ end
---@param base_uri string
---@param query string
---@return AvanteRagServiceRetrieveResponse | nil resp
---@return string | nil error
function M.retrieve(base_uri, query)
---@param on_complete fun(resp: AvanteRagServiceRetrieveResponse | nil, error: string | nil): nil
function M.retrieve(base_uri, query, on_complete)
base_uri = M.to_container_uri(base_uri)
local resp = curl.post(M.get_rag_service_url() .. "/api/v1/retrieve", {
curl.post(M.get_rag_service_url() .. "/api/v1/retrieve", {
headers = {
["Content-Type"] = "application/json",
},
@@ -313,20 +312,23 @@ function M.retrieve(base_uri, query)
top_k = 10,
}),
timeout = 100000,
callback = function(resp)
if resp.status ~= 200 then
Utils.error("failed to retrieve: " .. resp.body)
on_complete(nil, resp.body)
return
end
local jsn = vim.json.decode(resp.body)
jsn.sources = vim
.iter(jsn.sources)
:map(function(source)
local uri = M.to_local_uri(source.uri)
return vim.tbl_deep_extend("force", source, { uri = uri })
end)
:totable()
on_complete(jsn, nil)
end,
})
if resp.status ~= 200 then
Utils.error("failed to retrieve: " .. resp.body)
return nil, "failed to retrieve: " .. resp.body
end
local jsn = vim.json.decode(resp.body)
jsn.sources = vim
.iter(jsn.sources)
:map(function(source)
local uri = M.to_local_uri(source.uri)
return vim.tbl_deep_extend("force", source, { uri = uri })
end)
:totable()
return jsn, nil
end
---@class AvanteRagServiceIndexingStatusSummary