feat: RAG service (#1220)
This commit is contained in:
@@ -32,6 +32,9 @@ M._defaults = {
|
||||
-- For most providers that we support we will determine this automatically.
|
||||
-- If you wish to use a given implementation, then you can override it here.
|
||||
tokenizer = "tiktoken",
|
||||
rag_service = {
|
||||
enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set
|
||||
},
|
||||
web_search_engine = {
|
||||
provider = "tavily",
|
||||
providers = {
|
||||
|
||||
@@ -6,6 +6,7 @@ local Selection = require("avante.selection")
|
||||
local Suggestion = require("avante.suggestion")
|
||||
local Config = require("avante.config")
|
||||
local Diff = require("avante.diff")
|
||||
local RagService = require("avante.rag_service")
|
||||
|
||||
---@class Avante
|
||||
local M = {
|
||||
@@ -383,6 +384,41 @@ function M.setup(opts)
|
||||
H.signs()
|
||||
|
||||
M.did_setup = true
|
||||
|
||||
local function run_rag_service()
|
||||
local started_at = os.time()
|
||||
local add_resource_with_delay
|
||||
local function add_resource()
|
||||
local is_ready = RagService.is_ready()
|
||||
if not is_ready then
|
||||
local elapsed = os.time() - started_at
|
||||
if elapsed > 1000 * 60 * 15 then
|
||||
Utils.warn("Rag Service is not ready, giving up")
|
||||
return
|
||||
end
|
||||
add_resource_with_delay()
|
||||
return
|
||||
end
|
||||
vim.defer_fn(function()
|
||||
Utils.info("Adding project root to Rag Service ...")
|
||||
local uri = "file://" .. Utils.get_project_root()
|
||||
if uri:sub(-1) ~= "/" then uri = uri .. "/" end
|
||||
RagService.add_resource(uri)
|
||||
Utils.info("Added project root to Rag Service")
|
||||
end, 5000)
|
||||
end
|
||||
add_resource_with_delay = function()
|
||||
vim.defer_fn(function() add_resource() end, 5000)
|
||||
end
|
||||
vim.schedule(function()
|
||||
Utils.info("Starting Rag Service ...")
|
||||
RagService.launch_rag_service()
|
||||
Utils.info("Launched Rag Service")
|
||||
add_resource_with_delay()
|
||||
end)
|
||||
end
|
||||
|
||||
if Config.rag_service.enabled then run_rag_service() end
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -2,6 +2,9 @@ local curl = require("plenary.curl")
|
||||
local Utils = require("avante.utils")
|
||||
local Path = require("plenary.path")
|
||||
local Config = require("avante.config")
|
||||
local RagService = require("avante.rag_service")
|
||||
|
||||
---@class AvanteRagService
|
||||
local M = {}
|
||||
|
||||
---@param rel_path string
|
||||
@@ -533,6 +536,22 @@ function M.git_commit(opts, on_log)
|
||||
return true, nil
|
||||
end
|
||||
|
||||
---@param opts { query: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return string|nil result
|
||||
---@return string|nil error
|
||||
function M.rag_search(opts, on_log)
|
||||
if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end
|
||||
if not opts.query then return nil, "No query provided" end
|
||||
if on_log then on_log("query: " .. opts.query) end
|
||||
local root = Utils.get_project_root()
|
||||
local uri = "file://" .. root
|
||||
if uri:sub(-1) ~= "/" then uri = uri .. "/" end
|
||||
local resp, err = RagService.retrieve(uri, opts.query)
|
||||
if err then return nil, err end
|
||||
return vim.json.encode(resp), nil
|
||||
end
|
||||
|
||||
---@param opts { code: string, rel_path: string }
|
||||
---@param on_log? fun(log: string): nil
|
||||
---@return string|nil result
|
||||
@@ -554,8 +573,39 @@ function M.python(opts, on_log)
|
||||
return output, nil
|
||||
end
|
||||
|
||||
---@return AvanteLLMTool[]
|
||||
function M.get_tools() return M._tools end
|
||||
|
||||
---@type AvanteLLMTool[]
|
||||
M.tools = {
|
||||
M._tools = {
|
||||
{
|
||||
name = "rag_search",
|
||||
enabled = function() return Config.rag_service.enabled and RagService.is_ready() end,
|
||||
description = "Use Retrieval-Augmented Generation (RAG) to search for relevant information from an external knowledge base or documents. This tool retrieves relevant context from a large dataset and integrates it into the response generation process, improving accuracy and relevance. Use it when answering questions that require factual knowledge beyond what the model has been trained on.",
|
||||
param = {
|
||||
type = "table",
|
||||
fields = {
|
||||
{
|
||||
name = "query",
|
||||
description = "Query to search",
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
returns = {
|
||||
{
|
||||
name = "result",
|
||||
description = "Result of the search",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if the search was not successful",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name = "python",
|
||||
description = "Run python code",
|
||||
|
||||
302
lua/avante/rag_service.lua
Normal file
302
lua/avante/rag_service.lua
Normal file
@@ -0,0 +1,302 @@
|
||||
local curl = require("plenary.curl")
|
||||
local Path = require("plenary.path")
|
||||
local Utils = require("avante.utils")
|
||||
|
||||
local M = {}
|
||||
|
||||
local container_name = "avante-rag-service"
|
||||
|
||||
function M.get_rag_service_image() return "ghcr.io/yetone/avante-rag-service:0.0.3" end
|
||||
|
||||
function M.get_rag_service_port() return 20250 end
|
||||
|
||||
function M.get_rag_service_url() return string.format("http://localhost:%d", M.get_rag_service_port()) end
|
||||
|
||||
function M.get_data_path()
|
||||
local p = Path:new(vim.fn.stdpath("data")):joinpath("avante/rag_service")
|
||||
if not p:exists() then p:mkdir({ parents = true }) end
|
||||
return p
|
||||
end
|
||||
|
||||
function M.get_current_image()
|
||||
local cmd = string.format("docker inspect %s | grep Image | grep %s", container_name, container_name)
|
||||
local result = vim.fn.system(cmd)
|
||||
if result == "" then return nil end
|
||||
local exit_code = vim.v.shell_error
|
||||
if exit_code ~= 0 then return nil end
|
||||
local image = result:match('"Image":%s*"(.*)"')
|
||||
if image == nil then return nil end
|
||||
return image
|
||||
end
|
||||
|
||||
---@return boolean already_running
|
||||
function M.launch_rag_service()
|
||||
local openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
if openai_api_key == nil then
|
||||
error("cannot launch avante rag service, OPENAI_API_KEY is not set")
|
||||
return false
|
||||
end
|
||||
local openai_base_url = os.getenv("OPENAI_BASE_URL")
|
||||
if openai_base_url == nil then openai_base_url = "https://api.openai.com/v1" end
|
||||
local port = M.get_rag_service_port()
|
||||
local image = M.get_rag_service_image()
|
||||
local data_path = M.get_data_path()
|
||||
local cmd = string.format("docker ps -a | grep '%s'", container_name)
|
||||
local result = vim.fn.system(cmd)
|
||||
if result ~= "" then
|
||||
Utils.debug(string.format("container %s already running", container_name))
|
||||
local current_image = M.get_current_image()
|
||||
if current_image == image then return false end
|
||||
Utils.debug(
|
||||
string.format(
|
||||
"container %s is running with different image: %s != %s, stopping...",
|
||||
container_name,
|
||||
current_image,
|
||||
image
|
||||
)
|
||||
)
|
||||
M.stop_rag_service()
|
||||
else
|
||||
Utils.debug(string.format("container %s not found, starting...", container_name))
|
||||
end
|
||||
local cmd_ = string.format(
|
||||
"docker run -d -p %d:8000 --name %s -v %s:/data -v /:/host -e DATA_DIR=/data -e OPENAI_API_KEY=%s -e OPENAI_BASE_URL=%s %s",
|
||||
port,
|
||||
container_name,
|
||||
data_path,
|
||||
openai_api_key,
|
||||
openai_base_url,
|
||||
image
|
||||
)
|
||||
vim.fn.system(cmd_)
|
||||
Utils.debug(string.format("container %s started", container_name))
|
||||
return true
|
||||
end
|
||||
|
||||
function M.stop_rag_service()
|
||||
local cmd = string.format("docker ps -a | grep '%s'", container_name)
|
||||
local result = vim.fn.system(cmd)
|
||||
if result ~= "" then vim.fn.system(string.format("docker rm -fv %s", container_name)) end
|
||||
end
|
||||
|
||||
function M.get_rag_service_status()
|
||||
local cmd = string.format("docker ps -a | grep '%s'", container_name)
|
||||
local result = vim.fn.system(cmd)
|
||||
if result == "" then
|
||||
return "running"
|
||||
else
|
||||
return "stopped"
|
||||
end
|
||||
end
|
||||
|
||||
function M.get_scheme(uri)
|
||||
local scheme = uri:match("^(%w+)://")
|
||||
if scheme == nil then return "unknown" end
|
||||
return scheme
|
||||
end
|
||||
|
||||
function M.to_container_uri(uri)
|
||||
local scheme = M.get_scheme(uri)
|
||||
if scheme == "file" then
|
||||
local path = uri:match("^file://(.*)$")
|
||||
uri = string.format("file:///host%s", path)
|
||||
end
|
||||
return uri
|
||||
end
|
||||
|
||||
function M.to_local_uri(uri)
|
||||
local scheme = M.get_scheme(uri)
|
||||
if scheme == "file" then
|
||||
local path = uri:match("^file://host(.*)$")
|
||||
uri = string.format("file://%s", path)
|
||||
end
|
||||
return uri
|
||||
end
|
||||
|
||||
function M.is_ready()
|
||||
vim.fn.system(string.format("curl -s -o /dev/null -w '%%{http_code}' %s", M.get_rag_service_url()))
|
||||
return vim.v.shell_error == 0
|
||||
end
|
||||
|
||||
---@class AvanteRagServiceAddResourceResponse
|
||||
---@field status string
|
||||
---@field message string
|
||||
|
||||
---@param uri string
|
||||
---@return AvanteRagServiceAddResourceResponse | nil
|
||||
function M.add_resource(uri)
|
||||
uri = M.to_container_uri(uri)
|
||||
local resource_name = uri:match("([^/]+)/$")
|
||||
local resources_resp = M.get_resources()
|
||||
if resources_resp == nil then
|
||||
Utils.error("Failed to get resources")
|
||||
return nil
|
||||
end
|
||||
local already_added = false
|
||||
for _, resource in ipairs(resources_resp.resources) do
|
||||
if resource.uri == uri then
|
||||
already_added = true
|
||||
resource_name = resource.name
|
||||
break
|
||||
end
|
||||
end
|
||||
if not already_added then
|
||||
local names_map = {}
|
||||
for _, resource in ipairs(resources_resp.resources) do
|
||||
names_map[resource.name] = true
|
||||
end
|
||||
if names_map[resource_name] then
|
||||
for i = 1, 100 do
|
||||
local resource_name_ = string.format("%s-%d", resource_name, i)
|
||||
if not names_map[resource_name_] then
|
||||
resource_name = resource_name_
|
||||
break
|
||||
end
|
||||
end
|
||||
if names_map[resource_name] then
|
||||
Utils.error(string.format("Failed to add resource, name conflict: %s", resource_name))
|
||||
return nil
|
||||
end
|
||||
end
|
||||
end
|
||||
local resp = curl.post(M.get_rag_service_url() .. "/api/v1/add_resource", {
|
||||
headers = {
|
||||
["Content-Type"] = "application/json",
|
||||
},
|
||||
body = vim.json.encode({
|
||||
name = resource_name,
|
||||
uri = uri,
|
||||
}),
|
||||
})
|
||||
if resp.status ~= 200 then
|
||||
Utils.error("failed to add resource: " .. resp.body)
|
||||
return
|
||||
end
|
||||
return vim.json.decode(resp.body)
|
||||
end
|
||||
|
||||
function M.remove_resource(uri)
|
||||
uri = M.to_container_uri(uri)
|
||||
local resp = curl.post(M.get_rag_service_url() .. "/api/v1/remove_resource", {
|
||||
headers = {
|
||||
["Content-Type"] = "application/json",
|
||||
},
|
||||
body = vim.json.encode({
|
||||
uri = uri,
|
||||
}),
|
||||
})
|
||||
if resp.status ~= 200 then
|
||||
Utils.error("failed to remove resource: " .. resp.body)
|
||||
return
|
||||
end
|
||||
return vim.json.decode(resp.body)
|
||||
end
|
||||
|
||||
---@class AvanteRagServiceRetrieveSource
|
||||
---@field uri string
|
||||
---@field content string
|
||||
|
||||
---@class AvanteRagServiceRetrieveResponse
|
||||
---@field response string
|
||||
---@field sources AvanteRagServiceRetrieveSource[]
|
||||
|
||||
---@param base_uri string
|
||||
---@param query string
|
||||
---@return AvanteRagServiceRetrieveResponse | nil resp
|
||||
---@return string | nil error
|
||||
function M.retrieve(base_uri, query)
|
||||
base_uri = M.to_container_uri(base_uri)
|
||||
local resp = curl.post(M.get_rag_service_url() .. "/api/v1/retrieve", {
|
||||
headers = {
|
||||
["Content-Type"] = "application/json",
|
||||
},
|
||||
body = vim.json.encode({
|
||||
base_uri = base_uri,
|
||||
query = query,
|
||||
top_k = 10,
|
||||
}),
|
||||
})
|
||||
if resp.status ~= 200 then
|
||||
Utils.error("failed to retrieve: " .. resp.body)
|
||||
return nil, "failed to retrieve: " .. resp.body
|
||||
end
|
||||
local jsn = vim.json.decode(resp.body)
|
||||
jsn.sources = vim
|
||||
.iter(jsn.sources)
|
||||
:map(function(source)
|
||||
local uri = M.to_local_uri(source.uri)
|
||||
return vim.tbl_deep_extend("force", source, { uri = uri })
|
||||
end)
|
||||
:totable()
|
||||
return jsn, nil
|
||||
end
|
||||
|
||||
---@class AvanteRagServiceIndexingStatusSummary
|
||||
---@field indexing integer
|
||||
---@field completed integer
|
||||
---@field failed integer
|
||||
|
||||
---@class AvanteRagServiceIndexingStatusResponse
|
||||
---@field uri string
|
||||
---@field is_watched boolean
|
||||
---@field total_files integer
|
||||
---@field status_summary AvanteRagServiceIndexingStatusSummary
|
||||
|
||||
---@param uri string
|
||||
---@return AvanteRagServiceIndexingStatusResponse | nil
|
||||
function M.indexing_status(uri)
|
||||
uri = M.to_container_uri(uri)
|
||||
local resp = curl.post(M.get_rag_service_url() .. "/api/v1/indexing_status", {
|
||||
headers = {
|
||||
["Content-Type"] = "application/json",
|
||||
},
|
||||
body = vim.json.encode({
|
||||
uri = uri,
|
||||
}),
|
||||
})
|
||||
if resp.status ~= 200 then
|
||||
Utils.error("Failed to get indexing status: " .. resp.body)
|
||||
return
|
||||
end
|
||||
local jsn = vim.json.decode(resp.body)
|
||||
jsn.uri = M.to_local_uri(jsn.uri)
|
||||
return jsn
|
||||
end
|
||||
|
||||
---@class AvanteRagServiceResource
|
||||
---@field name string
|
||||
---@field uri string
|
||||
---@field type string
|
||||
---@field status string
|
||||
---@field indexing_status string
|
||||
---@field created_at string
|
||||
---@field indexing_started_at string | nil
|
||||
---@field last_indexed_at string | nil
|
||||
|
||||
---@class AvanteRagServiceResourceListResponse
|
||||
---@field resources AvanteRagServiceResource[]
|
||||
---@field total_count number
|
||||
|
||||
---@return AvanteRagServiceResourceListResponse | nil
|
||||
M.get_resources = function()
|
||||
local resp = curl.get(M.get_rag_service_url() .. "/api/v1/resources", {
|
||||
headers = {
|
||||
["Content-Type"] = "application/json",
|
||||
},
|
||||
})
|
||||
if resp.status ~= 200 then
|
||||
Utils.error("Failed to get resources: " .. resp.body)
|
||||
return
|
||||
end
|
||||
local jsn = vim.json.decode(resp.body)
|
||||
jsn.resources = vim
|
||||
.iter(jsn.resources)
|
||||
:map(function(resource)
|
||||
local uri = M.to_local_uri(resource.uri)
|
||||
return vim.tbl_deep_extend("force", resource, { uri = uri })
|
||||
end)
|
||||
:totable()
|
||||
return jsn
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1227,7 +1227,7 @@ function Sidebar:apply(current_cursor)
|
||||
if last_orig_diff_end_line > #original_code_lines then
|
||||
pcall(function() api.nvim_win_set_cursor(winid, { #original_code_lines, 0 }) end)
|
||||
else
|
||||
api.nvim_win_set_cursor(winid, { last_orig_diff_end_line, 0 })
|
||||
pcall(function() api.nvim_win_set_cursor(winid, { last_orig_diff_end_line, 0 }) end)
|
||||
end
|
||||
vim.cmd("normal! zz")
|
||||
end,
|
||||
@@ -2287,7 +2287,7 @@ function Sidebar:create_input_container(opts)
|
||||
|
||||
local chat_history = Path.history.load(self.code.bufnr)
|
||||
|
||||
local tools = vim.deepcopy(LLMTools.tools)
|
||||
local tools = vim.deepcopy(LLMTools.get_tools())
|
||||
table.insert(tools, {
|
||||
name = "add_file_to_context",
|
||||
description = "Add a file to the context",
|
||||
|
||||
@@ -2,8 +2,10 @@ Don't directly search for code context in historical messages. Instead, prioriti
|
||||
|
||||
Tools Usage Guide:
|
||||
- You have access to tools, but only use them when necessary. If a tool is not required, respond as normal.
|
||||
- If you encounter a URL, prioritize using the fetch tool to obtain its content.
|
||||
- If you have information that you don't know, please proactively use the tools provided by users! Especially the web search tool.
|
||||
- If the `rag_search` tool exists, prioritize using it to do the search!
|
||||
- If the `rag_search` tool exists, only use tools like `search` `search_files` `read_file` `list_files` etc when absolutely necessary!
|
||||
- If you encounter a URL, prioritize using the `fetch` tool to obtain its content.
|
||||
- If you have information that you don't know, please proactively use the tools provided by users! Especially the `web_search` tool.
|
||||
- When available tools cannot meet the requirements, please try to use the `run_command` tool to solve the problem whenever possible.
|
||||
- When attempting to modify a file that is not in the context, please first use the `list_files` tool and `search_files` tool to check if the file you want to modify exists, then use the `read_file` tool to read the file content. Don't modify blindly!
|
||||
- When generating files, first use `list_files` tool to read the directory structure, don't generate blindly!
|
||||
|
||||
@@ -329,6 +329,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field func? fun(input: any): (string | nil, string | nil)
|
||||
---@field param AvanteLLMToolParam
|
||||
---@field returns AvanteLLMToolReturn[]
|
||||
---@field enabled? fun(): boolean
|
||||
|
||||
---@class AvanteLLMToolParam
|
||||
---@field type string
|
||||
|
||||
Reference in New Issue
Block a user