feat: RAG service (#1220)

This commit is contained in:
yetone
2025-02-23 01:37:26 +08:00
committed by GitHub
parent 437d36920d
commit fd84c91cdb
32 changed files with 2339 additions and 15 deletions

View File

@@ -13,4 +13,3 @@ tab_width = 8
[{Makefile,**/Makefile}]
indent_style = tab
indent_size = 8

31
.github/workflows/python.yaml vendored Normal file
View File

@@ -0,0 +1,31 @@
name: Python
on:
push:
branches:
- main
paths:
- '**/*.py'
- '.github/workflows/python.yaml'
pull_request:
branches:
- main
paths:
- '**/*.py'
- '.github/workflows/python.yaml'
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: '3.11'
- run: |
python -m venv .venv
source .venv/bin/activate
pip install -r py/rag-service/requirements.txt
- uses: pre-commit/action@v3.0.1
with:
extra_args: --files $(find ./py -type f -name "*.py")

12
.gitignore vendored
View File

@@ -2,6 +2,10 @@
*.lua~
*.luac
.venv
__pycache__/
data/
# Neovim plugin specific files
plugin/packer_compiled.lua
@@ -31,14 +35,14 @@ temp/
.env
# If you use any build tools, you might need to ignore build output directories
/build/
/dist/
build/
dist/
# If you use any test frameworks, you might need to ignore test coverage reports
/coverage/
coverage/
# If you use documentation generation tools, you might need to ignore generated docs
/doc/
doc/
# If you have any personal configuration files, you should ignore them too
config.personal.lua

View File

@@ -1,12 +1,14 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
rev: v5.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-ast # Check for Python syntax errors
- id: debug-statements # Check for leftover debug statements
- repo: https://github.com/JohnnyMorganz/StyLua
rev: v0.20.0
rev: v2.0.2
hooks:
- id: stylua-system # or stylua / stylua-github
files: \.lua$
@@ -15,3 +17,21 @@ repos:
hooks:
- id: fmt
files: \.rs$
- id: cargo-check
args: ['--features', 'luajit']
files: \.rs$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.5
hooks:
# Run the Ruff linter
- id: ruff
args: [--fix]
# Run the Ruff formatter
- id: ruff-format
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.393
hooks:
- id: pyright
additional_dependencies:
- "types-setuptools"
- "types-requests"

View File

@@ -198,4 +198,4 @@
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.

View File

@@ -100,3 +100,8 @@ lint: luacheck luastylecheck ruststylecheck rustlint
.PHONY: lua-typecheck
lua-typecheck:
bash ./scripts/lua-typecheck.sh
.PHONY: build-image
build-image:
docker build -t ghcr.io/yetone/avante-rag-service:0.0.3 -f py/rag-service/Dockerfile py/rag-service
docker push ghcr.io/yetone/avante-rag-service:0.0.3

View File

@@ -622,6 +622,18 @@ Because avante.nvim has always used Aiders method for planning applying, but
Therefore, I have adopted Cursor's method to implement planning applying. For details on the implementation, please refer to [cursor-planning-mode.md](./cursor-planning-mode.md)
## RAG Service
Avante provides a RAG service, which is a tool for obtaining the required context for the AI to generate code. It is not enabled by default; you can enable it as follows:
```lua
rag_service = {
enabled = true, -- Enables the rag service, requires OPENAI_API_KEY to be set
},
```
Please note that since the RAG service uses OpenAI for embeddings, you must set `OPENAI_API_KEY` environment variable!
## Web Search Engines
Avante's tools include some web search engines, currently support:

View File

@@ -78,4 +78,3 @@ else
curl -L "$ARTIFACT_URL" | tar -zxv -C "$TARGET_DIR"
fi

View File

@@ -22,4 +22,3 @@
(enum_declaration
body: (enum_member_declaration_list
(enum_member_declaration) @enum_item))

View File

@@ -32,6 +32,9 @@ M._defaults = {
-- For most providers that we support we will determine this automatically.
-- If you wish to use a given implementation, then you can override it here.
tokenizer = "tiktoken",
rag_service = {
enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set
},
web_search_engine = {
provider = "tavily",
providers = {

View File

@@ -6,6 +6,7 @@ local Selection = require("avante.selection")
local Suggestion = require("avante.suggestion")
local Config = require("avante.config")
local Diff = require("avante.diff")
local RagService = require("avante.rag_service")
---@class Avante
local M = {
@@ -383,6 +384,41 @@ function M.setup(opts)
H.signs()
M.did_setup = true
--- Start the RAG service container and, once it reports ready, register the
--- project root as an indexed resource. Retries the readiness check every
--- 5 seconds and gives up after 15 minutes.
local function run_rag_service()
  local started_at = os.time()
  local add_resource_with_delay
  local function add_resource()
    local is_ready = RagService.is_ready()
    if not is_ready then
      -- os.time() returns seconds, so the give-up threshold must be in
      -- seconds. (The original compared against 1000 * 60 * 15 —
      -- millisecond-style math — which is roughly 10 days, not 15 minutes.)
      local elapsed = os.time() - started_at
      if elapsed > 60 * 15 then
        Utils.warn("Rag Service is not ready, giving up")
        return
      end
      add_resource_with_delay()
      return
    end
    vim.defer_fn(function()
      Utils.info("Adding project root to Rag Service ...")
      -- The service expects directory URIs with a trailing slash.
      local uri = "file://" .. Utils.get_project_root()
      if uri:sub(-1) ~= "/" then uri = uri .. "/" end
      RagService.add_resource(uri)
      Utils.info("Added project root to Rag Service")
    end, 5000)
  end
  add_resource_with_delay = function()
    vim.defer_fn(function() add_resource() end, 5000)
  end
  vim.schedule(function()
    Utils.info("Starting Rag Service ...")
    RagService.launch_rag_service()
    Utils.info("Launched Rag Service")
    add_resource_with_delay()
  end)
end
if Config.rag_service.enabled then run_rag_service() end
end
return M

View File

@@ -2,6 +2,9 @@ local curl = require("plenary.curl")
local Utils = require("avante.utils")
local Path = require("plenary.path")
local Config = require("avante.config")
local RagService = require("avante.rag_service")
---@class AvanteRagService
local M = {}
---@param rel_path string
@@ -533,6 +536,22 @@ function M.git_commit(opts, on_log)
return true, nil
end
--- Search the project via the RAG service.
---@param opts { query: string }
---@param on_log? fun(log: string): nil
---@return string|nil result JSON-encoded retrieval response
---@return string|nil error
function M.rag_search(opts, on_log)
  if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end
  if not opts.query then return nil, "No query provided" end
  if on_log then on_log("query: " .. opts.query) end
  -- The service indexes the project root as a directory URI ending in "/".
  local project_uri = "file://" .. Utils.get_project_root()
  if not project_uri:match("/$") then project_uri = project_uri .. "/" end
  local resp, err = RagService.retrieve(project_uri, opts.query)
  if err then return nil, err end
  return vim.json.encode(resp), nil
end
---@param opts { code: string, rel_path: string }
---@param on_log? fun(log: string): nil
---@return string|nil result
@@ -554,8 +573,39 @@ function M.python(opts, on_log)
return output, nil
end
---@return AvanteLLMTool[]
function M.get_tools() return M._tools end
---@type AvanteLLMTool[]
M.tools = {
M._tools = {
{
name = "rag_search",
enabled = function() return Config.rag_service.enabled and RagService.is_ready() end,
description = "Use Retrieval-Augmented Generation (RAG) to search for relevant information from an external knowledge base or documents. This tool retrieves relevant context from a large dataset and integrates it into the response generation process, improving accuracy and relevance. Use it when answering questions that require factual knowledge beyond what the model has been trained on.",
param = {
type = "table",
fields = {
{
name = "query",
description = "Query to search",
type = "string",
},
},
},
returns = {
{
name = "result",
description = "Result of the search",
type = "string",
},
{
name = "error",
description = "Error message if the search was not successful",
type = "string",
optional = true,
},
},
},
{
name = "python",
description = "Run python code",

302
lua/avante/rag_service.lua Normal file
View File

@@ -0,0 +1,302 @@
local curl = require("plenary.curl")
local Path = require("plenary.path")
local Utils = require("avante.utils")
local M = {}
local container_name = "avante-rag-service"
--- Image tag of the bundled RAG service container (built by `make build-image`).
function M.get_rag_service_image() return "ghcr.io/yetone/avante-rag-service:0.0.3" end

--- Host port the container's port 8000 is published on.
function M.get_rag_service_port() return 20250 end

--- Base URL of the locally running RAG service.
function M.get_rag_service_url() return string.format("http://localhost:%d", M.get_rag_service_port()) end
--- Ensure and return the on-disk data directory for the RAG service
--- (mounted into the container at /data).
---@return table path plenary Path object
function M.get_data_path()
  local data_dir = Path:new(vim.fn.stdpath("data")):joinpath("avante/rag_service")
  if not data_dir:exists() then data_dir:mkdir({ parents = true }) end
  return data_dir
end
--- Inspect the existing container and extract the image it was created from.
---@return string | nil image nil when the container does not exist or parsing fails
function M.get_current_image()
  local inspect_cmd = string.format("docker inspect %s | grep Image | grep %s", container_name, container_name)
  local output = vim.fn.system(inspect_cmd)
  if output == "" then return nil end
  if vim.v.shell_error ~= 0 then return nil end
  -- `docker inspect` JSON contains an "Image" field; pull out its value.
  return output:match('"Image":%s*"(.*)"')
end
--- Launch the RAG service container, replacing an existing container that was
--- created from a different image. Raises an error when OPENAI_API_KEY is
--- unset (the service uses OpenAI for embeddings).
---@return boolean started true when a new container was started, false when the current one was kept
function M.launch_rag_service()
  local openai_api_key = os.getenv("OPENAI_API_KEY")
  if openai_api_key == nil then
    -- error() never returns; the original's trailing `return false` here was
    -- unreachable and has been removed.
    error("cannot launch avante rag service, OPENAI_API_KEY is not set")
  end
  local openai_base_url = os.getenv("OPENAI_BASE_URL")
  if openai_base_url == nil then openai_base_url = "https://api.openai.com/v1" end
  local port = M.get_rag_service_port()
  local image = M.get_rag_service_image()
  local data_path = M.get_data_path()
  -- NOTE(review): `docker ps -a` also lists stopped containers and grep
  -- matches substrings of any column; this mirrors the original best-effort
  -- existence check rather than a true "is running" check.
  local cmd = string.format("docker ps -a | grep '%s'", container_name)
  local result = vim.fn.system(cmd)
  if result ~= "" then
    Utils.debug(string.format("container %s already running", container_name))
    local current_image = M.get_current_image()
    -- Same image: keep the existing container.
    if current_image == image then return false end
    Utils.debug(
      string.format(
        "container %s is running with different image: %s != %s, stopping...",
        container_name,
        current_image,
        image
      )
    )
    M.stop_rag_service()
  else
    Utils.debug(string.format("container %s not found, starting...", container_name))
  end
  -- Single-quote values taken from the environment so that shell
  -- metacharacters in the key/URL/path cannot split or inject into the
  -- command line.
  local cmd_ = string.format(
    "docker run -d -p %d:8000 --name %s -v '%s':/data -v /:/host -e DATA_DIR=/data -e OPENAI_API_KEY='%s' -e OPENAI_BASE_URL='%s' %s",
    port,
    container_name,
    data_path,
    openai_api_key,
    openai_base_url,
    image
  )
  vim.fn.system(cmd_)
  Utils.debug(string.format("container %s started", container_name))
  return true
end
--- Force-remove the RAG service container (and its anonymous volumes) if it exists.
function M.stop_rag_service()
  local listed = vim.fn.system(string.format("docker ps -a | grep '%s'", container_name))
  if listed == "" then return end
  vim.fn.system(string.format("docker rm -fv %s", container_name))
end
--- Report whether the RAG service container exists.
--- Fix: the branches were inverted — an empty `docker ps -a | grep` result
--- means no container was found, i.e. the service is stopped, not running.
---@return string status "running" when the container is listed, otherwise "stopped"
function M.get_rag_service_status()
  local cmd = string.format("docker ps -a | grep '%s'", container_name)
  local result = vim.fn.system(cmd)
  if result == "" then
    return "stopped"
  else
    -- NOTE(review): `docker ps -a` also lists exited containers, so this is
    -- really "exists" rather than "running" — confirm whether plain
    -- `docker ps` was intended.
    return "running"
  end
end
--- Extract the scheme ("file", "https", ...) from a URI.
---@param uri string
---@return string scheme "unknown" when the URI has no `scheme://` prefix
function M.get_scheme(uri)
  return uri:match("^(%w+)://") or "unknown"
end
--- Rewrite a local file:// URI so it points inside the container, where the
--- host filesystem is bind-mounted at /host. Non-file URIs pass through.
function M.to_container_uri(uri)
  if M.get_scheme(uri) ~= "file" then return uri end
  local path = uri:match("^file://(.*)$")
  return string.format("file:///host%s", path)
end
--- Reverse of `to_container_uri`: strip the container's /host mount prefix.
--- Fix: the original pattern "^file://host(.*)$" never matched the URIs that
--- `to_container_uri` produces ("file:///host/..."), so the prefix was never
--- stripped — and the failed match passed nil into string.format, which
--- errors under LuaJIT. Non-matching URIs are now returned unchanged.
--- NOTE(review): `add_resource` compares service-returned URIs against
--- container-form URIs; verify its dedup check still matches after this fix.
function M.to_local_uri(uri)
  local scheme = M.get_scheme(uri)
  if scheme == "file" then
    local path = uri:match("^file:///host(.*)$")
    if path then uri = string.format("file://%s", path) end
  end
  return uri
end
--- Probe the service endpoint with curl; ready when the request succeeds
--- (curl exit code 0 — the captured HTTP status code itself is unused).
---@return boolean
function M.is_ready()
  local probe = string.format("curl -s -o /dev/null -w '%%{http_code}' %s", M.get_rag_service_url())
  vim.fn.system(probe)
  return vim.v.shell_error == 0
end
---@class AvanteRagServiceAddResourceResponse
---@field status string
---@field message string

--- Register a resource with the RAG service for indexing.
--- If the URI is already registered, its existing name is reused; otherwise a
--- name is derived from the URI's last path segment, probing "-1".."-100"
--- suffixes on name collisions.
---@param uri string local URI; expected to end with "/" (see note below)
---@return AvanteRagServiceAddResourceResponse | nil
function M.add_resource(uri)
  uri = M.to_container_uri(uri)
  -- The segment before the trailing slash becomes the resource name.
  -- NOTE(review): this match yields nil when the URI lacks a trailing slash —
  -- current callers always append one; confirm before passing other shapes.
  local resource_name = uri:match("([^/]+)/$")
  local resources_resp = M.get_resources()
  if resources_resp == nil then
    Utils.error("Failed to get resources")
    return nil
  end
  -- Dedup by URI: reuse the existing registration if present.
  -- NOTE(review): `uri` here is in container form while get_resources maps
  -- URIs through to_local_uri — confirm both sides are in the same form.
  local already_added = false
  for _, resource in ipairs(resources_resp.resources) do
    if resource.uri == uri then
      already_added = true
      resource_name = resource.name
      break
    end
  end
  if not already_added then
    -- Resource names are unique service-side; find a free suffixed name.
    local names_map = {}
    for _, resource in ipairs(resources_resp.resources) do
      names_map[resource.name] = true
    end
    if names_map[resource_name] then
      for i = 1, 100 do
        local resource_name_ = string.format("%s-%d", resource_name, i)
        if not names_map[resource_name_] then
          resource_name = resource_name_
          break
        end
      end
      -- All 100 suffixes taken: give up.
      if names_map[resource_name] then
        Utils.error(string.format("Failed to add resource, name conflict: %s", resource_name))
        return nil
      end
    end
  end
  local resp = curl.post(M.get_rag_service_url() .. "/api/v1/add_resource", {
    headers = {
      ["Content-Type"] = "application/json",
    },
    body = vim.json.encode({
      name = resource_name,
      uri = uri,
    }),
  })
  if resp.status ~= 200 then
    Utils.error("failed to add resource: " .. resp.body)
    return
  end
  return vim.json.decode(resp.body)
end
--- Unregister a resource from the RAG service.
---@param uri string local URI of the resource to remove
---@return table | nil decoded response body, nil on HTTP failure
function M.remove_resource(uri)
  local payload = vim.json.encode({ uri = M.to_container_uri(uri) })
  local resp = curl.post(M.get_rag_service_url() .. "/api/v1/remove_resource", {
    headers = { ["Content-Type"] = "application/json" },
    body = payload,
  })
  if resp.status ~= 200 then
    Utils.error("failed to remove resource: " .. resp.body)
    return
  end
  return vim.json.decode(resp.body)
end
---@class AvanteRagServiceRetrieveSource
---@field uri string
---@field content string

---@class AvanteRagServiceRetrieveResponse
---@field response string
---@field sources AvanteRagServiceRetrieveSource[]

--- Ask the RAG service for context relevant to `query` within `base_uri`.
--- Source URIs in the response are mapped back from container form to local.
---@param base_uri string local URI of the indexed resource
---@param query string
---@return AvanteRagServiceRetrieveResponse | nil resp
---@return string | nil error
function M.retrieve(base_uri, query)
  base_uri = M.to_container_uri(base_uri)
  local resp = curl.post(M.get_rag_service_url() .. "/api/v1/retrieve", {
    headers = {
      ["Content-Type"] = "application/json",
    },
    body = vim.json.encode({
      base_uri = base_uri,
      query = query,
      top_k = 10, -- fixed fan-out; the service returns at most 10 sources
    }),
  })
  if resp.status ~= 200 then
    Utils.error("failed to retrieve: " .. resp.body)
    return nil, "failed to retrieve: " .. resp.body
  end
  local jsn = vim.json.decode(resp.body)
  -- Map each source URI from the container's /host mount back to local form.
  jsn.sources = vim
    .iter(jsn.sources)
    :map(function(source)
      local uri = M.to_local_uri(source.uri)
      return vim.tbl_deep_extend("force", source, { uri = uri })
    end)
    :totable()
  return jsn, nil
end
---@class AvanteRagServiceIndexingStatusSummary
---@field indexing integer
---@field completed integer
---@field failed integer

---@class AvanteRagServiceIndexingStatusResponse
---@field uri string
---@field is_watched boolean
---@field total_files integer
---@field status_summary AvanteRagServiceIndexingStatusSummary

--- Query indexing progress for a resource; the returned URI is mapped back
--- to local form.
---@param uri string
---@return AvanteRagServiceIndexingStatusResponse | nil
function M.indexing_status(uri)
  local payload = vim.json.encode({ uri = M.to_container_uri(uri) })
  local resp = curl.post(M.get_rag_service_url() .. "/api/v1/indexing_status", {
    headers = { ["Content-Type"] = "application/json" },
    body = payload,
  })
  if resp.status ~= 200 then
    Utils.error("Failed to get indexing status: " .. resp.body)
    return
  end
  local jsn = vim.json.decode(resp.body)
  jsn.uri = M.to_local_uri(jsn.uri)
  return jsn
end
---@class AvanteRagServiceResource
---@field name string
---@field uri string
---@field type string
---@field status string
---@field indexing_status string
---@field created_at string
---@field indexing_started_at string | nil
---@field last_indexed_at string | nil

---@class AvanteRagServiceResourceListResponse
---@field resources AvanteRagServiceResource[]
---@field total_count number

--- List all resources registered with the RAG service, with each URI mapped
--- back from container form to local form.
---@return AvanteRagServiceResourceListResponse | nil
M.get_resources = function()
  local resp = curl.get(M.get_rag_service_url() .. "/api/v1/resources", {
    headers = { ["Content-Type"] = "application/json" },
  })
  if resp.status ~= 200 then
    Utils.error("Failed to get resources: " .. resp.body)
    return
  end
  local jsn = vim.json.decode(resp.body)
  local localized = {}
  for i, resource in ipairs(jsn.resources) do
    localized[i] = vim.tbl_deep_extend("force", resource, { uri = M.to_local_uri(resource.uri) })
  end
  jsn.resources = localized
  return jsn
end
return M

View File

@@ -1227,7 +1227,7 @@ function Sidebar:apply(current_cursor)
if last_orig_diff_end_line > #original_code_lines then
pcall(function() api.nvim_win_set_cursor(winid, { #original_code_lines, 0 }) end)
else
api.nvim_win_set_cursor(winid, { last_orig_diff_end_line, 0 })
pcall(function() api.nvim_win_set_cursor(winid, { last_orig_diff_end_line, 0 }) end)
end
vim.cmd("normal! zz")
end,
@@ -2287,7 +2287,7 @@ function Sidebar:create_input_container(opts)
local chat_history = Path.history.load(self.code.bufnr)
local tools = vim.deepcopy(LLMTools.tools)
local tools = vim.deepcopy(LLMTools.get_tools())
table.insert(tools, {
name = "add_file_to_context",
description = "Add a file to the context",

View File

@@ -2,8 +2,10 @@ Don't directly search for code context in historical messages. Instead, prioriti
Tools Usage Guide:
- You have access to tools, but only use them when necessary. If a tool is not required, respond as normal.
- If you encounter a URL, prioritize using the fetch tool to obtain its content.
- If you have information that you don't know, please proactively use the tools provided by users! Especially the web search tool.
- If the `rag_search` tool exists, prioritize using it to do the search!
- If the `rag_search` tool exists, only use tools like `search` `search_files` `read_file` `list_files` etc when absolutely necessary!
- If you encounter a URL, prioritize using the `fetch` tool to obtain its content.
- If you have information that you don't know, please proactively use the tools provided by users! Especially the `web_search` tool.
- When available tools cannot meet the requirements, please try to use the `run_command` tool to solve the problem whenever possible.
- When attempting to modify a file that is not in the context, please first use the `list_files` tool and `search_files` tool to check if the file you want to modify exists, then use the `read_file` tool to read the file content. Don't modify blindly!
- When generating files, first use `list_files` tool to read the directory structure, don't generate blindly!

View File

@@ -329,6 +329,7 @@ vim.g.avante_login = vim.g.avante_login
---@field func? fun(input: any): (string | nil, string | nil)
---@field param AvanteLLMToolParam
---@field returns AvanteLLMToolReturn[]
---@field enabled? fun(): boolean
---@class AvanteLLMToolParam
---@field type string

33
py/rag-service/Dockerfile Normal file
View File

@@ -0,0 +1,33 @@
# RAG service image: Python 3.11 managed by uv on Debian bookworm-slim.
FROM debian:bookworm-slim

WORKDIR /app

# Build tooling for native wheels, curl for the uv installer script.
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/* \
    && curl -LsSf https://astral.sh/uv/install.sh | sh

# uv installer places the binary under /root/.local/bin.
ENV PATH="/root/.local/bin:$PATH"

RUN uv python install 3.11
RUN uv python list
# NOTE(review): uv typically installs interpreters under
# /root/.local/share/uv/python — confirm this PATH entry is actually effective.
ENV PATH="/root/.uv/python/3.11/bin:$PATH"

# Copy and install dependencies before the source so code changes do not
# invalidate this cached layer.
COPY requirements.txt .
RUN uv venv --python 3.11
RUN uv pip install -r requirements.txt

ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PORT=8000

EXPOSE ${PORT}

COPY . .

CMD ["uv", "run", "fastapi", "run", "src/main.py", "--workers", "3"]

View File

@@ -0,0 +1,162 @@
aiohappyeyeballs==2.4.6
aiohttp==3.11.12
aiosignal==1.3.2
annotated-types==0.7.0
anyio==4.8.0
asgiref==3.8.1
asttokens==3.0.0
attrs==25.1.0
backoff==2.2.1
bcrypt==4.2.1
beautifulsoup4==4.13.3
build==1.2.2.post1
cachetools==5.5.1
certifi==2024.12.14
charset-normalizer==3.4.1
chroma-hnswlib==0.7.6
chromadb==0.6.3
click==8.1.8
coloredlogs==15.0.1
dataclasses-json==0.6.7
decorator==5.1.1
Deprecated==1.2.18
dirtyjson==1.0.8
distro==1.9.0
dnspython==2.7.0
durationpy==0.9
email_validator==2.2.0
executing==2.2.0
fastapi==0.115.8
fastapi-cli==0.0.7
filelock==3.17.0
filetype==1.2.0
flatbuffers==25.1.24
frozenlist==1.5.0
fsspec==2025.2.0
google-auth==2.38.0
googleapis-common-protos==1.66.0
greenlet==3.1.1
grpcio==1.70.0
h11==0.14.0
httpcore==1.0.7
httptools==0.6.4
httpx==0.28.1
huggingface-hub==0.28.1
humanfriendly==10.0
idna==3.10
importlib_metadata==8.5.0
importlib_resources==6.5.2
ipython==8.32.0
jedi==0.19.2
Jinja2==3.1.5
jiter==0.8.2
joblib==1.4.2
kubernetes==32.0.0
llama-cloud==0.1.11
llama-cloud-services==0.6.0
llama-index==0.12.16
llama-index-agent-openai==0.4.3
llama-index-cli==0.4.0
llama-index-core==0.12.16.post1
llama-index-embeddings-openai==0.3.1
llama-index-indices-managed-llama-cloud==0.6.4
llama-index-llms-openai==0.3.18
llama-index-multi-modal-llms-openai==0.4.3
llama-index-program-openai==0.3.1
llama-index-question-gen-openai==0.3.0
llama-index-readers-file==0.4.4
llama-index-readers-llama-parse==0.4.0
llama-index-vector-stores-chroma==0.4.1
llama-parse==0.6.0
markdown-it-py==3.0.0
markdownify==0.14.1
MarkupSafe==3.0.2
marshmallow==3.26.1
matplotlib-inline==0.1.7
mdurl==0.1.2
mmh3==5.1.0
monotonic==1.6
mpmath==1.3.0
multidict==6.1.0
mypy-extensions==1.0.0
nest-asyncio==1.6.0
networkx==3.4.2
nltk==3.9.1
numpy==2.2.2
oauthlib==3.2.2
onnxruntime==1.20.1
openai==1.61.1
opentelemetry-api==1.30.0
opentelemetry-exporter-otlp-proto-common==1.30.0
opentelemetry-exporter-otlp-proto-grpc==1.30.0
opentelemetry-instrumentation==0.51b0
opentelemetry-instrumentation-asgi==0.51b0
opentelemetry-instrumentation-fastapi==0.51b0
opentelemetry-proto==1.30.0
opentelemetry-sdk==1.30.0
opentelemetry-semantic-conventions==0.51b0
opentelemetry-util-http==0.51b0
orjson==3.10.15
overrides==7.7.0
packaging==24.2
pandas==2.2.3
parso==0.8.4
pathspec==0.12.1
pexpect==4.9.0
pillow==11.1.0
posthog==3.11.0
prompt_toolkit==3.0.50
propcache==0.2.1
protobuf==5.29.3
ptyprocess==0.7.0
pure_eval==0.2.3
pyasn1==0.6.1
pyasn1_modules==0.4.1
pydantic==2.10.6
pydantic_core==2.27.2
Pygments==2.19.1
pypdf==5.2.0
PyPika==0.48.9
pyproject_hooks==1.2.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-multipart==0.0.20
pytz==2025.1
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
requests-oauthlib==2.0.0
rich==13.9.4
rich-toolkit==0.13.2
rsa==4.9
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
soupsieve==2.6
SQLAlchemy==2.0.38
stack-data==0.6.3
starlette==0.45.3
striprtf==0.0.26
sympy==1.13.3
tenacity==9.0.0
tiktoken==0.8.0
tokenizers==0.21.0
tqdm==4.67.1
traitlets==5.14.3
tree-sitter==0.21.3
tree-sitter-languages==1.10.2
typer==0.15.1
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2025.1
urllib3==2.3.0
uvicorn==0.34.0
uvloop==0.21.0
watchdog==6.0.0
watchfiles==1.0.4
wcwidth==0.2.13
websocket-client==1.8.0
websockets==14.2
wrapt==1.17.2
yarl==1.18.3
zipp==3.21.0

View File

View File

@@ -0,0 +1,14 @@
import os
from pathlib import Path
# Configuration.
# Root directory for all service state; overridable via DATA_DIR (the Docker
# image mounts the host data directory at /data and sets DATA_DIR=/data).
BASE_DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
# Chroma vector-store persistence directory.
CHROMA_PERSIST_DIR = BASE_DATA_DIR / "chroma_db"
# Service log files.
LOG_DIR = BASE_DATA_DIR / "logs"
# SQLite database holding indexing history and resource records.
DB_FILE = BASE_DATA_DIR / "sqlite" / "indexing_history.db"

# Create all directories up front so later opens cannot fail on a missing path.
BASE_DATA_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR.mkdir(parents=True, exist_ok=True)
DB_FILE.parent.mkdir(parents=True, exist_ok=True)  # Create sqlite directory
CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)

View File

@@ -0,0 +1,60 @@
import sqlite3
from collections.abc import Generator
from contextlib import contextmanager
from libs.configs import DB_FILE
# SQLite table schemas
CREATE_TABLES_SQL = """
CREATE TABLE IF NOT EXISTS indexing_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
uri TEXT NOT NULL,
content_hash TEXT NOT NULL,
status TEXT NOT NULL,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
error_message TEXT,
document_id TEXT,
metadata TEXT
);
CREATE INDEX IF NOT EXISTS idx_uri ON indexing_history(uri);
CREATE INDEX IF NOT EXISTS idx_document_id ON indexing_history(document_id);
CREATE INDEX IF NOT EXISTS idx_content_hash ON indexing_history(content_hash);
CREATE TABLE IF NOT EXISTS resources (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
uri TEXT NOT NULL UNIQUE,
type TEXT NOT NULL, -- 'path' or 'https'
status TEXT NOT NULL DEFAULT 'active', -- 'active' or 'inactive'
indexing_status TEXT NOT NULL DEFAULT 'pending', -- 'pending', 'indexing', 'indexed', 'failed'
indexing_status_message TEXT,
indexing_started_at DATETIME,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_indexed_at DATETIME,
last_error TEXT
);
CREATE INDEX IF NOT EXISTS idx_resources_name ON resources(name);
CREATE INDEX IF NOT EXISTS idx_resources_uri ON resources(uri);
CREATE INDEX IF NOT EXISTS idx_resources_status ON resources(status);
CREATE INDEX IF NOT EXISTS idx_status ON indexing_history(status);
"""
@contextmanager
def get_db_connection() -> Generator[sqlite3.Connection, None, None]:
    """Yield a SQLite connection with name-addressable rows, closing on exit."""
    connection = sqlite3.connect(DB_FILE)
    # sqlite3.Row lets callers address columns by name and supports dict(row).
    connection.row_factory = sqlite3.Row
    try:
        yield connection
    finally:
        connection.close()
def init_db() -> None:
    """Create all tables and indexes if they do not already exist (idempotent)."""
    with get_db_connection() as db:
        db.executescript(CREATE_TABLES_SQL)
        db.commit()

View File

@@ -0,0 +1,16 @@
import logging
from datetime import datetime
from libs.configs import LOG_DIR
# Log to both a per-day file under LOG_DIR and the console stream.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler(
            # File name is fixed at import time; a long-running process keeps
            # writing to the file named for its start date.
            LOG_DIR / f"rag_service_{datetime.now().astimezone().strftime('%Y%m%d')}.log",
        ),
        logging.StreamHandler(),
    ],
)

# Shared module-level logger for the service.
logger = logging.getLogger(__name__)

View File

@@ -0,0 +1,66 @@
from __future__ import annotations
import re
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from llama_index.core.schema import BaseNode
# Matches chunked document ids of the form "<uri>__part_<n>".
PATTERN_URI_PART = re.compile(r"(?P<uri>.+)__part_\d+")
# Node metadata key under which the source URI is stored.
METADATA_KEY_URI = "uri"
def uri_to_path(uri: str) -> Path:
    """Convert a file:// URI to a filesystem path.

    Fix: use str.removeprefix() so only a leading "file://" is stripped;
    str.replace() would also mangle any "file://" occurring later in the path.
    """
    return Path(uri.removeprefix("file://"))
def path_to_uri(file_path: Path) -> str:
    """Convert a path to a file:// URI; existing directories get a trailing slash."""
    if file_path.is_dir():
        return file_path.as_uri() + "/"
    return file_path.as_uri()
def is_local_uri(uri: str) -> bool:
    """Return True for file:// URIs (resources on the local filesystem)."""
    return uri.startswith("file://")


def is_remote_uri(uri: str) -> bool:
    """Return True for http:// or https:// URIs (remote resources)."""
    return uri.startswith("https://") or uri.startswith("http://")
def is_path_node(node: BaseNode) -> bool:
    """True when the node's URI resolves and points at the local filesystem."""
    uri = get_node_uri(node)
    return bool(uri) and is_local_uri(uri)
def get_node_uri(node: BaseNode) -> str | None:
    """Best-effort URI for a node: metadata first, then derived from doc_id.

    Chunked doc_ids of the form "<uri>__part_<n>" collapse to "<uri>"; bare
    absolute paths are normalized to file:// URIs. Returns None when nothing
    usable is found.
    """
    uri = node.metadata.get(METADATA_KEY_URI)
    if not uri:
        doc_id = getattr(node, "doc_id", None)
        if doc_id:
            match = PATTERN_URI_PART.match(doc_id)
            uri = match.group("uri") if match else doc_id
    if not uri:
        return None
    if uri.startswith("/"):
        uri = f"file://{uri}"
    return uri
def inject_uri_to_node(node: BaseNode) -> None:
    """Store the node's derived URI in its metadata unless already present."""
    if METADATA_KEY_URI in node.metadata:
        return
    resolved = get_node_uri(node)
    if resolved:
        node.metadata[METADATA_KEY_URI] = resolved

1114
py/rag-service/src/main.py Normal file

File diff suppressed because it is too large Load Diff

View File

View File

@@ -0,0 +1,19 @@
"""Indexing History Model."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class IndexingHistory(BaseModel):
    """Model for indexing history record.

    One row per status transition of a document; rows are written by
    IndexingHistoryService and reduced to the latest row per document on read.
    """

    id: int | None = Field(None, description="Record ID")
    uri: str = Field(..., description="URI of the indexed file")
    # NOTE(review): description says MD5, but the hash comes from the
    # document's `hash` attribute elsewhere — confirm the algorithm.
    content_hash: str = Field(..., description="MD5 hash of the file content")
    status: str = Field(..., description="Indexing status (indexing/completed/failed)")
    timestamp: datetime = Field(default_factory=datetime.now, description="Record timestamp")
    error_message: str | None = Field(None, description="Error message if failed")
    document_id: str | None = Field(None, description="Document ID in the index")
    metadata: dict[str, Any] | None = Field(None, description="Additional metadata")

View File

@@ -0,0 +1,25 @@
"""Resource Model."""
from datetime import datetime
from typing import Literal
from pydantic import BaseModel, Field
class Resource(BaseModel):
    """Model for resource record (a directory or URL registered for indexing)."""

    id: int | None = Field(None, description="Resource ID")
    name: str = Field(..., description="Name of the resource")
    uri: str = Field(..., description="URI of the resource")
    # NOTE(review): the Literal values are "local"/"remote" while both the
    # description and the DB schema comment say 'path'/'https' — confirm which
    # vocabulary is actually stored.
    type: Literal["local", "remote"] = Field(..., description="Type of resource (path/https)")
    status: str = Field("active", description="Status of resource (active/inactive)")
    indexing_status: Literal["pending", "indexing", "indexed", "failed"] = Field(
        "pending",
        description="Indexing status (pending/indexing/indexed/failed)",
    )
    indexing_status_message: str | None = Field(None, description="Indexing status message")
    created_at: datetime = Field(default_factory=datetime.now, description="Creation timestamp")
    indexing_started_at: datetime | None = Field(None, description="Indexing start timestamp")
    last_indexed_at: datetime | None = Field(None, description="Last indexing timestamp")
    last_error: str | None = Field(None, description="Last error message if any")

View File

View File

@@ -0,0 +1,174 @@
import json
import os
from datetime import datetime
from typing import Any
from libs.db import get_db_connection
from libs.logger import logger
from libs.utils import get_node_uri
from llama_index.core.schema import Document
from models.indexing_history import IndexingHistory
class IndexingHistoryService:
    """Persistence helpers for the indexing_history table.

    Rows record each status transition of an indexed document; read paths
    reduce to the latest row per document (or per URI).
    """

    def delete_indexing_status(self, uri: str) -> None:
        """Delete all indexing-history rows for a specific file URI."""
        with get_db_connection() as conn:
            conn.execute(
                """
                DELETE FROM indexing_history
                WHERE uri = ?
                """,
                (uri,),
            )
            conn.commit()

    def delete_indexing_status_by_document_id(self, document_id: str) -> None:
        """Delete all indexing-history rows for a specific document id."""
        with get_db_connection() as conn:
            conn.execute(
                """
                DELETE FROM indexing_history
                WHERE document_id = ?
                """,
                (document_id,),
            )
            conn.commit()

    def update_indexing_status(
        self,
        doc: Document,
        status: str,
        error_message: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Insert or update the indexing status row for `doc`.

        Silently skips (with a warning) when no URI can be derived for the
        document, since rows without a URI are not queryable.
        """
        content_hash = doc.hash
        # Get URI from metadata if available
        uri = get_node_uri(doc)
        if not uri:
            logger.warning("URI not found for document: %s", doc.doc_id)
            return
        record = IndexingHistory(
            id=None,
            uri=uri,
            content_hash=content_hash,
            status=status,
            error_message=error_message,
            document_id=doc.doc_id,
            metadata=metadata,
        )
        with get_db_connection() as conn:
            # Check if record exists
            existing = conn.execute(
                "SELECT id FROM indexing_history WHERE document_id = ?",
                (doc.doc_id,),
            ).fetchone()
            if existing:
                # Update existing record.
                # NOTE(review): existence is checked by document_id but the
                # UPDATE filters by uri — if multiple document_ids share a uri
                # this updates all of their rows; confirm which key is intended.
                conn.execute(
                    """
                    UPDATE indexing_history
                    SET content_hash = ?, status = ?, error_message = ?, document_id = ?, metadata = ?
                    WHERE uri = ?
                    """,
                    (
                        record.content_hash,
                        record.status,
                        record.error_message,
                        record.document_id,
                        json.dumps(record.metadata) if record.metadata else None,
                        record.uri,
                    ),
                )
            else:
                # Insert new record
                conn.execute(
                    """
                    INSERT INTO indexing_history
                    (uri, content_hash, status, error_message, document_id, metadata)
                    VALUES (?, ?, ?, ?, ?, ?)
                    """,
                    (
                        record.uri,
                        record.content_hash,
                        record.status,
                        record.error_message,
                        record.document_id,
                        json.dumps(record.metadata) if record.metadata else None,
                    ),
                )
            conn.commit()

    def get_indexing_status(self, doc: Document | None = None, base_uri: str | None = None) -> list[IndexingHistory]:
        """Return the latest indexing status row(s).

        - `doc` given: latest row for that document's (uri, content_hash).
        - `base_uri` given: latest row per document under that URI prefix.
        - neither: latest row per uri across the whole table.
        """
        with get_db_connection() as conn:
            if doc:
                uri = get_node_uri(doc)
                if not uri:
                    logger.warning("URI not found for document: %s", doc.doc_id)
                    return []
                content_hash = doc.hash
                # For a specific file, get its latest status
                query = """
                    SELECT *
                    FROM indexing_history
                    WHERE uri = ? and content_hash = ?
                    ORDER BY timestamp DESC LIMIT 1
                """
                params = (uri, content_hash)
            elif base_uri:
                # For files in a specific directory, get their latest status
                query = """
                    WITH RankedHistory AS (
                        SELECT *,
                            ROW_NUMBER() OVER (PARTITION BY document_id ORDER BY timestamp DESC) as rn
                        FROM indexing_history
                        WHERE uri LIKE ? || '%'
                    )
                    SELECT id, uri, content_hash, status, timestamp, error_message, document_id, metadata
                    FROM RankedHistory
                    WHERE rn = 1
                    ORDER BY timestamp DESC
                """
                # NOTE(review): base_uri is a URI, yet the separator appended
                # is os.path.sep — not "/" on non-POSIX hosts. The service runs
                # in a Linux container, but confirm for local dev runs.
                params = (base_uri,) if base_uri.endswith(os.path.sep) else (base_uri + os.path.sep,)
            else:
                # For all files, get their latest status
                query = """
                    WITH RankedHistory AS (
                        SELECT *,
                            ROW_NUMBER() OVER (PARTITION BY uri ORDER BY timestamp DESC) as rn
                        FROM indexing_history
                    )
                    SELECT id, uri, content_hash, status, timestamp, error_message, document_id, metadata
                    FROM RankedHistory
                    WHERE rn = 1
                    ORDER BY timestamp DESC
                """
                params = ()
            rows = conn.execute(query, params).fetchall()
            result = []
            for row in rows:
                row_dict = dict(row)
                # Parse metadata JSON if it exists
                if row_dict.get("metadata"):
                    try:
                        row_dict["metadata"] = json.loads(row_dict["metadata"])
                    except json.JSONDecodeError:
                        row_dict["metadata"] = None
                # Parse timestamp string to datetime if needed
                if isinstance(row_dict.get("timestamp"), str):
                    row_dict["timestamp"] = datetime.fromisoformat(
                        row_dict["timestamp"].replace("Z", "+00:00"),
                    )
                result.append(IndexingHistory(**row_dict))
            return result


# Module-level singleton used by the API layer.
indexing_history_service = IndexingHistoryService()

View File

@@ -0,0 +1,104 @@
"""Resource Service."""
from libs.db import get_db_connection
from models.resource import Resource
class ResourceService:
    """CRUD helper around the ``resources`` table."""

    def add_resource_to_db(self, resource: Resource) -> None:
        """Insert a new resource row from the given model."""
        insert_sql = """
            INSERT INTO resources (name, uri, type, status, indexing_status, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
        """
        values = (
            resource.name,
            resource.uri,
            resource.type,
            resource.status,
            resource.indexing_status,
            resource.created_at,
        )
        with get_db_connection() as conn:
            conn.execute(insert_sql, values)
            conn.commit()

    def update_resource_indexing_status(self, uri: str, indexing_status: str, indexing_status_message: str) -> None:
        """Update a resource's indexing status and message.

        Entering the ``indexing`` state stamps ``indexing_started_at``;
        any other state stamps ``last_indexed_at`` instead.
        """
        if indexing_status == "indexing":
            update_sql = """
                UPDATE resources
                SET indexing_status = ?, indexing_status_message = ?, indexing_started_at = CURRENT_TIMESTAMP
                WHERE uri = ?
            """
        else:
            update_sql = """
                UPDATE resources
                SET indexing_status = ?, indexing_status_message = ?, last_indexed_at = CURRENT_TIMESTAMP
                WHERE uri = ?
            """
        with get_db_connection() as conn:
            conn.execute(update_sql, (indexing_status, indexing_status_message, uri))
            conn.commit()

    def update_resource_status(self, uri: str, status: str, error: str | None = None) -> None:
        """Update a resource's status and last error.

        Becoming ``active`` also refreshes ``last_indexed_at``.
        """
        if status == "active":
            update_sql = """
                UPDATE resources
                SET status = ?, last_indexed_at = CURRENT_TIMESTAMP, last_error = ?
                WHERE uri = ?
            """
        else:
            update_sql = """
                UPDATE resources
                SET status = ?, last_error = ?
                WHERE uri = ?
            """
        with get_db_connection() as conn:
            conn.execute(update_sql, (status, error, uri))
            conn.commit()

    def get_resource(self, uri: str) -> Resource | None:
        """Look up a resource by URI; None when absent."""
        with get_db_connection() as conn:
            found = conn.execute(
                "SELECT * FROM resources WHERE uri = ?",
                (uri,),
            ).fetchone()
        return Resource(**dict(found)) if found else None

    def get_resource_by_name(self, name: str) -> Resource | None:
        """Look up a resource by name; None when absent."""
        with get_db_connection() as conn:
            found = conn.execute(
                "SELECT * FROM resources WHERE name = ?",
                (name,),
            ).fetchone()
        return Resource(**dict(found)) if found else None

    def get_all_resources(self) -> list[Resource]:
        """Return every resource, newest first by creation time."""
        with get_db_connection() as conn:
            rows = conn.execute("SELECT * FROM resources ORDER BY created_at DESC").fetchall()
        return [Resource(**dict(row)) for row in rows]
resource_service = ResourceService()

25
pyrightconfig.json Normal file
View File

@@ -0,0 +1,25 @@
{
"include": [
"."
],
"exclude": [
"**/node_modules",
"**/__pycache__",
"**/.*"
],
"defineConstant": {
"DEBUG": true
},
"venvPath": ".",
"venv": ".venv",
"pythonVersion": "3.11",
"typeCheckingMode": "strict",
"reportMissingImports": true,
"reportMissingTypeStubs": false,
"reportUnknownMemberType": false,
"reportUnknownParameterType": false,
"reportUnknownVariableType": false,
"reportUnknownArgumentType": false,
"reportPrivateUsage": false,
"reportUntypedFunctionDecorator": false
}

49
ruff.toml Normal file
View File

@@ -0,0 +1,49 @@
# Line length kept consistent with black
line-length = 180
# Directories to exclude
exclude = [
".git",
".ruff_cache",
".venv",
"venv",
"__pycache__",
"build",
"dist",
]
# Target Python version
# NOTE(review): pyrightconfig.json and the CI workflow pin 3.11 — confirm whether py312 is intended here.
target-version = "py312"
[lint]
# Enable all rule sets
select = ["ALL"]
# Rules to ignore
ignore = [
"A005",
"BLE001",
"D104",
"D100",
"D101",
"D203", # 1 blank line required before class docstring
"D212", # Multi-line docstring summary should start at the first line
"TRY300",
"TRY400",
"PGH003",
"PLR0911",
]
# Allow auto-fixing for all rules
fixable = ["ALL"]
[format]
# Use double quotes
quote-style = "double"
# Indentation style
indent-style = "space"
[lint.isort]
# Import-sorting settings compatible with black
combine-as-imports = true
known-first-party = ["avante"]