feat: RAG service (#1220)

This commit is contained in:
yetone
2025-02-23 01:37:26 +08:00
committed by GitHub
parent 437d36920d
commit fd84c91cdb
32 changed files with 2339 additions and 15 deletions

View File

@@ -32,6 +32,9 @@ M._defaults = {
-- For most providers that we support we will determine this automatically.
-- If you wish to use a given implementation, then you can override it here.
tokenizer = "tiktoken",
rag_service = {
enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set
},
web_search_engine = {
provider = "tavily",
providers = {

View File

@@ -6,6 +6,7 @@ local Selection = require("avante.selection")
local Suggestion = require("avante.suggestion")
local Config = require("avante.config")
local Diff = require("avante.diff")
local RagService = require("avante.rag_service")
---@class Avante
local M = {
@@ -383,6 +384,41 @@ function M.setup(opts)
H.signs()
M.did_setup = true
local function run_rag_service()
local started_at = os.time()
local add_resource_with_delay
local function add_resource()
local is_ready = RagService.is_ready()
if not is_ready then
local elapsed = os.time() - started_at
if elapsed > 1000 * 60 * 15 then
Utils.warn("Rag Service is not ready, giving up")
return
end
add_resource_with_delay()
return
end
vim.defer_fn(function()
Utils.info("Adding project root to Rag Service ...")
local uri = "file://" .. Utils.get_project_root()
if uri:sub(-1) ~= "/" then uri = uri .. "/" end
RagService.add_resource(uri)
Utils.info("Added project root to Rag Service")
end, 5000)
end
add_resource_with_delay = function()
vim.defer_fn(function() add_resource() end, 5000)
end
vim.schedule(function()
Utils.info("Starting Rag Service ...")
RagService.launch_rag_service()
Utils.info("Launched Rag Service")
add_resource_with_delay()
end)
end
if Config.rag_service.enabled then run_rag_service() end
end
return M

View File

@@ -2,6 +2,9 @@ local curl = require("plenary.curl")
local Utils = require("avante.utils")
local Path = require("plenary.path")
local Config = require("avante.config")
local RagService = require("avante.rag_service")
---@class AvanteRagService
local M = {}
---@param rel_path string
@@ -533,6 +536,22 @@ function M.git_commit(opts, on_log)
return true, nil
end
---@param opts { query: string }
---@param on_log? fun(log: string): nil
---@return string|nil result
---@return string|nil error
function M.rag_search(opts, on_log)
if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end
if not opts.query then return nil, "No query provided" end
if on_log then on_log("query: " .. opts.query) end
local root = Utils.get_project_root()
local uri = "file://" .. root
if uri:sub(-1) ~= "/" then uri = uri .. "/" end
local resp, err = RagService.retrieve(uri, opts.query)
if err then return nil, err end
return vim.json.encode(resp), nil
end
---@param opts { code: string, rel_path: string }
---@param on_log? fun(log: string): nil
---@return string|nil result
@@ -554,8 +573,39 @@ function M.python(opts, on_log)
return output, nil
end
---@return AvanteLLMTool[]
function M.get_tools() return M._tools end
---@type AvanteLLMTool[]
M.tools = {
M._tools = {
{
name = "rag_search",
enabled = function() return Config.rag_service.enabled and RagService.is_ready() end,
description = "Use Retrieval-Augmented Generation (RAG) to search for relevant information from an external knowledge base or documents. This tool retrieves relevant context from a large dataset and integrates it into the response generation process, improving accuracy and relevance. Use it when answering questions that require factual knowledge beyond what the model has been trained on.",
param = {
type = "table",
fields = {
{
name = "query",
description = "Query to search",
type = "string",
},
},
},
returns = {
{
name = "result",
description = "Result of the search",
type = "string",
},
{
name = "error",
description = "Error message if the search was not successful",
type = "string",
optional = true,
},
},
},
{
name = "python",
description = "Run python code",

302
lua/avante/rag_service.lua Normal file
View File

@@ -0,0 +1,302 @@
local curl = require("plenary.curl")
local Path = require("plenary.path")
local Utils = require("avante.utils")
local M = {}
local container_name = "avante-rag-service"
function M.get_rag_service_image() return "ghcr.io/yetone/avante-rag-service:0.0.3" end
function M.get_rag_service_port() return 20250 end
function M.get_rag_service_url() return string.format("http://localhost:%d", M.get_rag_service_port()) end
function M.get_data_path()
local p = Path:new(vim.fn.stdpath("data")):joinpath("avante/rag_service")
if not p:exists() then p:mkdir({ parents = true }) end
return p
end
function M.get_current_image()
local cmd = string.format("docker inspect %s | grep Image | grep %s", container_name, container_name)
local result = vim.fn.system(cmd)
if result == "" then return nil end
local exit_code = vim.v.shell_error
if exit_code ~= 0 then return nil end
local image = result:match('"Image":%s*"(.*)"')
if image == nil then return nil end
return image
end
---@return boolean already_running
function M.launch_rag_service()
local openai_api_key = os.getenv("OPENAI_API_KEY")
if openai_api_key == nil then
error("cannot launch avante rag service, OPENAI_API_KEY is not set")
return false
end
local openai_base_url = os.getenv("OPENAI_BASE_URL")
if openai_base_url == nil then openai_base_url = "https://api.openai.com/v1" end
local port = M.get_rag_service_port()
local image = M.get_rag_service_image()
local data_path = M.get_data_path()
local cmd = string.format("docker ps -a | grep '%s'", container_name)
local result = vim.fn.system(cmd)
if result ~= "" then
Utils.debug(string.format("container %s already running", container_name))
local current_image = M.get_current_image()
if current_image == image then return false end
Utils.debug(
string.format(
"container %s is running with different image: %s != %s, stopping...",
container_name,
current_image,
image
)
)
M.stop_rag_service()
else
Utils.debug(string.format("container %s not found, starting...", container_name))
end
local cmd_ = string.format(
"docker run -d -p %d:8000 --name %s -v %s:/data -v /:/host -e DATA_DIR=/data -e OPENAI_API_KEY=%s -e OPENAI_BASE_URL=%s %s",
port,
container_name,
data_path,
openai_api_key,
openai_base_url,
image
)
vim.fn.system(cmd_)
Utils.debug(string.format("container %s started", container_name))
return true
end
function M.stop_rag_service()
local cmd = string.format("docker ps -a | grep '%s'", container_name)
local result = vim.fn.system(cmd)
if result ~= "" then vim.fn.system(string.format("docker rm -fv %s", container_name)) end
end
function M.get_rag_service_status()
local cmd = string.format("docker ps -a | grep '%s'", container_name)
local result = vim.fn.system(cmd)
if result == "" then
return "running"
else
return "stopped"
end
end
function M.get_scheme(uri)
local scheme = uri:match("^(%w+)://")
if scheme == nil then return "unknown" end
return scheme
end
function M.to_container_uri(uri)
local scheme = M.get_scheme(uri)
if scheme == "file" then
local path = uri:match("^file://(.*)$")
uri = string.format("file:///host%s", path)
end
return uri
end
function M.to_local_uri(uri)
local scheme = M.get_scheme(uri)
if scheme == "file" then
local path = uri:match("^file://host(.*)$")
uri = string.format("file://%s", path)
end
return uri
end
function M.is_ready()
vim.fn.system(string.format("curl -s -o /dev/null -w '%%{http_code}' %s", M.get_rag_service_url()))
return vim.v.shell_error == 0
end
---@class AvanteRagServiceAddResourceResponse
---@field status string
---@field message string
---@param uri string
---@return AvanteRagServiceAddResourceResponse | nil
function M.add_resource(uri)
uri = M.to_container_uri(uri)
local resource_name = uri:match("([^/]+)/$")
local resources_resp = M.get_resources()
if resources_resp == nil then
Utils.error("Failed to get resources")
return nil
end
local already_added = false
for _, resource in ipairs(resources_resp.resources) do
if resource.uri == uri then
already_added = true
resource_name = resource.name
break
end
end
if not already_added then
local names_map = {}
for _, resource in ipairs(resources_resp.resources) do
names_map[resource.name] = true
end
if names_map[resource_name] then
for i = 1, 100 do
local resource_name_ = string.format("%s-%d", resource_name, i)
if not names_map[resource_name_] then
resource_name = resource_name_
break
end
end
if names_map[resource_name] then
Utils.error(string.format("Failed to add resource, name conflict: %s", resource_name))
return nil
end
end
end
local resp = curl.post(M.get_rag_service_url() .. "/api/v1/add_resource", {
headers = {
["Content-Type"] = "application/json",
},
body = vim.json.encode({
name = resource_name,
uri = uri,
}),
})
if resp.status ~= 200 then
Utils.error("failed to add resource: " .. resp.body)
return
end
return vim.json.decode(resp.body)
end
function M.remove_resource(uri)
uri = M.to_container_uri(uri)
local resp = curl.post(M.get_rag_service_url() .. "/api/v1/remove_resource", {
headers = {
["Content-Type"] = "application/json",
},
body = vim.json.encode({
uri = uri,
}),
})
if resp.status ~= 200 then
Utils.error("failed to remove resource: " .. resp.body)
return
end
return vim.json.decode(resp.body)
end
---@class AvanteRagServiceRetrieveSource
---@field uri string
---@field content string
---@class AvanteRagServiceRetrieveResponse
---@field response string
---@field sources AvanteRagServiceRetrieveSource[]
---@param base_uri string
---@param query string
---@return AvanteRagServiceRetrieveResponse | nil resp
---@return string | nil error
function M.retrieve(base_uri, query)
base_uri = M.to_container_uri(base_uri)
local resp = curl.post(M.get_rag_service_url() .. "/api/v1/retrieve", {
headers = {
["Content-Type"] = "application/json",
},
body = vim.json.encode({
base_uri = base_uri,
query = query,
top_k = 10,
}),
})
if resp.status ~= 200 then
Utils.error("failed to retrieve: " .. resp.body)
return nil, "failed to retrieve: " .. resp.body
end
local jsn = vim.json.decode(resp.body)
jsn.sources = vim
.iter(jsn.sources)
:map(function(source)
local uri = M.to_local_uri(source.uri)
return vim.tbl_deep_extend("force", source, { uri = uri })
end)
:totable()
return jsn, nil
end
---@class AvanteRagServiceIndexingStatusSummary
---@field indexing integer
---@field completed integer
---@field failed integer
---@class AvanteRagServiceIndexingStatusResponse
---@field uri string
---@field is_watched boolean
---@field total_files integer
---@field status_summary AvanteRagServiceIndexingStatusSummary
---@param uri string
---@return AvanteRagServiceIndexingStatusResponse | nil
function M.indexing_status(uri)
uri = M.to_container_uri(uri)
local resp = curl.post(M.get_rag_service_url() .. "/api/v1/indexing_status", {
headers = {
["Content-Type"] = "application/json",
},
body = vim.json.encode({
uri = uri,
}),
})
if resp.status ~= 200 then
Utils.error("Failed to get indexing status: " .. resp.body)
return
end
local jsn = vim.json.decode(resp.body)
jsn.uri = M.to_local_uri(jsn.uri)
return jsn
end
---@class AvanteRagServiceResource
---@field name string
---@field uri string
---@field type string
---@field status string
---@field indexing_status string
---@field created_at string
---@field indexing_started_at string | nil
---@field last_indexed_at string | nil
---@class AvanteRagServiceResourceListResponse
---@field resources AvanteRagServiceResource[]
---@field total_count number
---@return AvanteRagServiceResourceListResponse | nil
M.get_resources = function()
local resp = curl.get(M.get_rag_service_url() .. "/api/v1/resources", {
headers = {
["Content-Type"] = "application/json",
},
})
if resp.status ~= 200 then
Utils.error("Failed to get resources: " .. resp.body)
return
end
local jsn = vim.json.decode(resp.body)
jsn.resources = vim
.iter(jsn.resources)
:map(function(resource)
local uri = M.to_local_uri(resource.uri)
return vim.tbl_deep_extend("force", resource, { uri = uri })
end)
:totable()
return jsn
end
return M

View File

@@ -1227,7 +1227,7 @@ function Sidebar:apply(current_cursor)
if last_orig_diff_end_line > #original_code_lines then
pcall(function() api.nvim_win_set_cursor(winid, { #original_code_lines, 0 }) end)
else
api.nvim_win_set_cursor(winid, { last_orig_diff_end_line, 0 })
pcall(function() api.nvim_win_set_cursor(winid, { last_orig_diff_end_line, 0 }) end)
end
vim.cmd("normal! zz")
end,
@@ -2287,7 +2287,7 @@ function Sidebar:create_input_container(opts)
local chat_history = Path.history.load(self.code.bufnr)
local tools = vim.deepcopy(LLMTools.tools)
local tools = vim.deepcopy(LLMTools.get_tools())
table.insert(tools, {
name = "add_file_to_context",
description = "Add a file to the context",

View File

@@ -2,8 +2,10 @@ Don't directly search for code context in historical messages. Instead, prioriti
Tools Usage Guide:
- You have access to tools, but only use them when necessary. If a tool is not required, respond as normal.
- If you encounter a URL, prioritize using the fetch tool to obtain its content.
- If you have information that you don't know, please proactively use the tools provided by users! Especially the web search tool.
- If the `rag_search` tool exists, prioritize using it to do the search!
- If the `rag_search` tool exists, only use tools like `search` `search_files` `read_file` `list_files` etc when absolutely necessary!
- If you encounter a URL, prioritize using the `fetch` tool to obtain its content.
- If you have information that you don't know, please proactively use the tools provided by users! Especially the `web_search` tool.
- When available tools cannot meet the requirements, please try to use the `run_command` tool to solve the problem whenever possible.
- When attempting to modify a file that is not in the context, please first use the `list_files` tool and `search_files` tool to check if the file you want to modify exists, then use the `read_file` tool to read the file content. Don't modify blindly!
- When generating files, first use `list_files` tool to read the directory structure, don't generate blindly!

View File

@@ -329,6 +329,7 @@ vim.g.avante_login = vim.g.avante_login
---@field func? fun(input: any): (string | nil, string | nil)
---@field param AvanteLLMToolParam
---@field returns AvanteLLMToolReturn[]
---@field enabled? fun(): boolean
---@class AvanteLLMToolParam
---@field type string