From fd84c91cdbfee035d430fa7bd79ed66d4be8f53d Mon Sep 17 00:00:00 2001 From: yetone Date: Sun, 23 Feb 2025 01:37:26 +0800 Subject: [PATCH] feat: RAG service (#1220) --- .editorconfig | 1 - .github/workflows/python.yaml | 31 + .gitignore | 12 +- .pre-commit-config.yaml | 24 +- LICENSE | 2 +- Makefile | 5 + README.md | 12 + build.sh | 1 - .../queries/tree-sitter-c-sharp-defs.scm | 1 - lua/avante/config.lua | 3 + lua/avante/init.lua | 36 + lua/avante/llm_tools.lua | 52 +- lua/avante/rag_service.lua | 302 +++++ lua/avante/sidebar.lua | 4 +- .../templates/_tools-guidelines.avanterules | 6 +- lua/avante/types.lua | 1 + py/rag-service/Dockerfile | 33 + py/rag-service/requirements.txt | 162 +++ py/rag-service/src/libs/__init__.py | 0 py/rag-service/src/libs/configs.py | 14 + py/rag-service/src/libs/db.py | 60 + py/rag-service/src/libs/logger.py | 16 + py/rag-service/src/libs/utils.py | 66 + py/rag-service/src/main.py | 1114 +++++++++++++++++ py/rag-service/src/models/__init__.py | 0 py/rag-service/src/models/indexing_history.py | 19 + py/rag-service/src/models/resource.py | 25 + py/rag-service/src/services/__init__.py | 0 .../src/services/indexing_history.py | 174 +++ py/rag-service/src/services/resource.py | 104 ++ pyrightconfig.json | 25 + ruff.toml | 49 + 32 files changed, 2339 insertions(+), 15 deletions(-) create mode 100644 .github/workflows/python.yaml create mode 100644 lua/avante/rag_service.lua create mode 100644 py/rag-service/Dockerfile create mode 100644 py/rag-service/requirements.txt create mode 100644 py/rag-service/src/libs/__init__.py create mode 100644 py/rag-service/src/libs/configs.py create mode 100644 py/rag-service/src/libs/db.py create mode 100644 py/rag-service/src/libs/logger.py create mode 100644 py/rag-service/src/libs/utils.py create mode 100644 py/rag-service/src/main.py create mode 100644 py/rag-service/src/models/__init__.py create mode 100644 py/rag-service/src/models/indexing_history.py create mode 100644 
py/rag-service/src/models/resource.py create mode 100644 py/rag-service/src/services/__init__.py create mode 100644 py/rag-service/src/services/indexing_history.py create mode 100644 py/rag-service/src/services/resource.py create mode 100644 pyrightconfig.json create mode 100644 ruff.toml diff --git a/.editorconfig b/.editorconfig index 16907cb..275e3e1 100644 --- a/.editorconfig +++ b/.editorconfig @@ -13,4 +13,3 @@ tab_width = 8 [{Makefile,**/Makefile}] indent_style = tab indent_size = 8 - diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml new file mode 100644 index 0000000..ff4dea8 --- /dev/null +++ b/.github/workflows/python.yaml @@ -0,0 +1,31 @@ +name: Python + +on: + push: + branches: + - main + paths: + - '**/*.py' + - '.github/workflows/python.yaml' + pull_request: + branches: + - main + paths: + - '**/*.py' + - '.github/workflows/python.yaml' + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + with: + python-version: '3.11' + - run: | + python -m venv .venv + source .venv/bin/activate + pip install -r py/rag-service/requirements.txt + - uses: pre-commit/action@v3.0.1 + with: + extra_args: --files $(find ./py -type f -name "*.py") diff --git a/.gitignore b/.gitignore index fd5a47b..7d040a1 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,10 @@ *.lua~ *.luac +.venv +__pycache__/ +data/ + # Neovim plugin specific files plugin/packer_compiled.lua @@ -31,14 +35,14 @@ temp/ .env # If you use any build tools, you might need to ignore build output directories -/build/ -/dist/ +build/ +dist/ # If you use any test frameworks, you might need to ignore test coverage reports -/coverage/ +coverage/ # If you use documentation generation tools, you might need to ignore generated docs -/doc/ +doc/ # If you have any personal configuration files, you should ignore them too config.personal.lua diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2a15022..e448f93 
100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,14 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 + rev: v5.0.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace + - id: check-ast # 检查Python语法错误 + - id: debug-statements # 检查是否有debug语句 - repo: https://github.com/JohnnyMorganz/StyLua - rev: v0.20.0 + rev: v2.0.2 hooks: - id: stylua-system # or stylua-system / stylua-github files: \.lua$ @@ -15,3 +17,21 @@ repos: hooks: - id: fmt files: \.rs$ + - id: cargo-check + args: ['--features', 'luajit'] + files: \.rs$ +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.5 + hooks: + # 运行 Ruff linter + - id: ruff + args: [--fix] + # 运行 Ruff formatter + - id: ruff-format +- repo: https://github.com/RobertCraigie/pyright-python + rev: v1.1.393 + hooks: + - id: pyright + additional_dependencies: + - "types-setuptools" + - "types-requests" diff --git a/LICENSE b/LICENSE index f49a4e1..261eeb9 100644 --- a/LICENSE +++ b/LICENSE @@ -198,4 +198,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file + limitations under the License. 
diff --git a/Makefile b/Makefile index 08b7e57..da6ffaf 100644 --- a/Makefile +++ b/Makefile @@ -100,3 +100,8 @@ lint: luacheck luastylecheck ruststylecheck rustlint .PHONY: lua-typecheck lua-typecheck: bash ./scripts/lua-typecheck.sh + +.PHONY: build-image +build-image: + docker build -t ghcr.io/yetone/avante-rag-service:0.0.3 -f py/rag-service/Dockerfile py/rag-service + docker push ghcr.io/yetone/avante-rag-service:0.0.3 diff --git a/README.md b/README.md index ccd44cf..81ac216 100644 --- a/README.md +++ b/README.md @@ -622,6 +622,18 @@ Because avante.nvim has always used Aider’s method for planning applying, but Therefore, I have adopted Cursor’s method to implement planning applying. For details on the implementation, please refer to [cursor-planning-mode.md](./cursor-planning-mode.md) +## RAG Service + +Avante provides a RAG service, which is a tool for obtaining the required context for the AI to generate the codes. Default it not enabled, you can enable it in this way: + +```lua +rag_service = { + enabled = true, -- Enables the rag service, requires OPENAI_API_KEY to be set +}, +``` + +Please note that since the RAG service uses OpenAI for embeddings, you must set `OPENAI_API_KEY` environment variable! 
+ ## Web Search Engines Avante's tools include some web search engines, currently support: diff --git a/build.sh b/build.sh index 570b970..6c79c9a 100644 --- a/build.sh +++ b/build.sh @@ -78,4 +78,3 @@ else curl -L "$ARTIFACT_URL" | tar -zxv -C "$TARGET_DIR" fi - diff --git a/crates/avante-repo-map/queries/tree-sitter-c-sharp-defs.scm b/crates/avante-repo-map/queries/tree-sitter-c-sharp-defs.scm index 84afcb8..b8a39bf 100644 --- a/crates/avante-repo-map/queries/tree-sitter-c-sharp-defs.scm +++ b/crates/avante-repo-map/queries/tree-sitter-c-sharp-defs.scm @@ -22,4 +22,3 @@ (enum_declaration body: (enum_member_declaration_list (enum_member_declaration) @enum_item)) - diff --git a/lua/avante/config.lua b/lua/avante/config.lua index a370aa2..75a7d35 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -32,6 +32,9 @@ M._defaults = { -- For most providers that we support we will determine this automatically. -- If you wish to use a given implementation, then you can override it here. 
tokenizer = "tiktoken", + rag_service = { + enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set + }, web_search_engine = { provider = "tavily", providers = { diff --git a/lua/avante/init.lua b/lua/avante/init.lua index 25e9d2a..0547286 100644 --- a/lua/avante/init.lua +++ b/lua/avante/init.lua @@ -6,6 +6,7 @@ local Selection = require("avante.selection") local Suggestion = require("avante.suggestion") local Config = require("avante.config") local Diff = require("avante.diff") +local RagService = require("avante.rag_service") ---@class Avante local M = { @@ -383,6 +384,41 @@ function M.setup(opts) H.signs() M.did_setup = true + + local function run_rag_service() + local started_at = os.time() + local add_resource_with_delay + local function add_resource() + local is_ready = RagService.is_ready() + if not is_ready then + local elapsed = os.time() - started_at + if elapsed > 1000 * 60 * 15 then + Utils.warn("Rag Service is not ready, giving up") + return + end + add_resource_with_delay() + return + end + vim.defer_fn(function() + Utils.info("Adding project root to Rag Service ...") + local uri = "file://" .. Utils.get_project_root() + if uri:sub(-1) ~= "/" then uri = uri .. 
"/" end + RagService.add_resource(uri) + Utils.info("Added project root to Rag Service") + end, 5000) + end + add_resource_with_delay = function() + vim.defer_fn(function() add_resource() end, 5000) + end + vim.schedule(function() + Utils.info("Starting Rag Service ...") + RagService.launch_rag_service() + Utils.info("Launched Rag Service") + add_resource_with_delay() + end) + end + + if Config.rag_service.enabled then run_rag_service() end end return M diff --git a/lua/avante/llm_tools.lua b/lua/avante/llm_tools.lua index d4f4069..e6b183c 100644 --- a/lua/avante/llm_tools.lua +++ b/lua/avante/llm_tools.lua @@ -2,6 +2,9 @@ local curl = require("plenary.curl") local Utils = require("avante.utils") local Path = require("plenary.path") local Config = require("avante.config") +local RagService = require("avante.rag_service") + +---@class AvanteRagService local M = {} ---@param rel_path string @@ -533,6 +536,22 @@ function M.git_commit(opts, on_log) return true, nil end +---@param opts { query: string } +---@param on_log? fun(log: string): nil +---@return string|nil result +---@return string|nil error +function M.rag_search(opts, on_log) + if not Config.rag_service.enabled then return nil, "Rag service is not enabled" end + if not opts.query then return nil, "No query provided" end + if on_log then on_log("query: " .. opts.query) end + local root = Utils.get_project_root() + local uri = "file://" .. root + if uri:sub(-1) ~= "/" then uri = uri .. "/" end + local resp, err = RagService.retrieve(uri, opts.query) + if err then return nil, err end + return vim.json.encode(resp), nil +end + ---@param opts { code: string, rel_path: string } ---@param on_log? 
fun(log: string): nil ---@return string|nil result @@ -554,8 +573,39 @@ function M.python(opts, on_log) return output, nil end +---@return AvanteLLMTool[] +function M.get_tools() return M._tools end + ---@type AvanteLLMTool[] -M.tools = { +M._tools = { + { + name = "rag_search", + enabled = function() return Config.rag_service.enabled and RagService.is_ready() end, + description = "Use Retrieval-Augmented Generation (RAG) to search for relevant information from an external knowledge base or documents. This tool retrieves relevant context from a large dataset and integrates it into the response generation process, improving accuracy and relevance. Use it when answering questions that require factual knowledge beyond what the model has been trained on.", + param = { + type = "table", + fields = { + { + name = "query", + description = "Query to search", + type = "string", + }, + }, + }, + returns = { + { + name = "result", + description = "Result of the search", + type = "string", + }, + { + name = "error", + description = "Error message if the search was not successful", + type = "string", + optional = true, + }, + }, + }, { name = "python", description = "Run python code", diff --git a/lua/avante/rag_service.lua b/lua/avante/rag_service.lua new file mode 100644 index 0000000..b59d05e --- /dev/null +++ b/lua/avante/rag_service.lua @@ -0,0 +1,302 @@ +local curl = require("plenary.curl") +local Path = require("plenary.path") +local Utils = require("avante.utils") + +local M = {} + +local container_name = "avante-rag-service" + +function M.get_rag_service_image() return "ghcr.io/yetone/avante-rag-service:0.0.3" end + +function M.get_rag_service_port() return 20250 end + +function M.get_rag_service_url() return string.format("http://localhost:%d", M.get_rag_service_port()) end + +function M.get_data_path() + local p = Path:new(vim.fn.stdpath("data")):joinpath("avante/rag_service") + if not p:exists() then p:mkdir({ parents = true }) end + return p +end + +function 
M.get_current_image() + local cmd = string.format("docker inspect %s | grep Image | grep %s", container_name, container_name) + local result = vim.fn.system(cmd) + if result == "" then return nil end + local exit_code = vim.v.shell_error + if exit_code ~= 0 then return nil end + local image = result:match('"Image":%s*"(.*)"') + if image == nil then return nil end + return image +end + +---@return boolean already_running +function M.launch_rag_service() + local openai_api_key = os.getenv("OPENAI_API_KEY") + if openai_api_key == nil then + error("cannot launch avante rag service, OPENAI_API_KEY is not set") + return false + end + local openai_base_url = os.getenv("OPENAI_BASE_URL") + if openai_base_url == nil then openai_base_url = "https://api.openai.com/v1" end + local port = M.get_rag_service_port() + local image = M.get_rag_service_image() + local data_path = M.get_data_path() + local cmd = string.format("docker ps -a | grep '%s'", container_name) + local result = vim.fn.system(cmd) + if result ~= "" then + Utils.debug(string.format("container %s already running", container_name)) + local current_image = M.get_current_image() + if current_image == image then return false end + Utils.debug( + string.format( + "container %s is running with different image: %s != %s, stopping...", + container_name, + current_image, + image + ) + ) + M.stop_rag_service() + else + Utils.debug(string.format("container %s not found, starting...", container_name)) + end + local cmd_ = string.format( + "docker run -d -p %d:8000 --name %s -v %s:/data -v /:/host -e DATA_DIR=/data -e OPENAI_API_KEY=%s -e OPENAI_BASE_URL=%s %s", + port, + container_name, + data_path, + openai_api_key, + openai_base_url, + image + ) + vim.fn.system(cmd_) + Utils.debug(string.format("container %s started", container_name)) + return true +end + +function M.stop_rag_service() + local cmd = string.format("docker ps -a | grep '%s'", container_name) + local result = vim.fn.system(cmd) + if result ~= "" then 
vim.fn.system(string.format("docker rm -fv %s", container_name)) end +end + +function M.get_rag_service_status() + local cmd = string.format("docker ps -a | grep '%s'", container_name) + local result = vim.fn.system(cmd) + if result == "" then + return "running" + else + return "stopped" + end +end + +function M.get_scheme(uri) + local scheme = uri:match("^(%w+)://") + if scheme == nil then return "unknown" end + return scheme +end + +function M.to_container_uri(uri) + local scheme = M.get_scheme(uri) + if scheme == "file" then + local path = uri:match("^file://(.*)$") + uri = string.format("file:///host%s", path) + end + return uri +end + +function M.to_local_uri(uri) + local scheme = M.get_scheme(uri) + if scheme == "file" then + local path = uri:match("^file://host(.*)$") + uri = string.format("file://%s", path) + end + return uri +end + +function M.is_ready() + vim.fn.system(string.format("curl -s -o /dev/null -w '%%{http_code}' %s", M.get_rag_service_url())) + return vim.v.shell_error == 0 +end + +---@class AvanteRagServiceAddResourceResponse +---@field status string +---@field message string + +---@param uri string +---@return AvanteRagServiceAddResourceResponse | nil +function M.add_resource(uri) + uri = M.to_container_uri(uri) + local resource_name = uri:match("([^/]+)/$") + local resources_resp = M.get_resources() + if resources_resp == nil then + Utils.error("Failed to get resources") + return nil + end + local already_added = false + for _, resource in ipairs(resources_resp.resources) do + if resource.uri == uri then + already_added = true + resource_name = resource.name + break + end + end + if not already_added then + local names_map = {} + for _, resource in ipairs(resources_resp.resources) do + names_map[resource.name] = true + end + if names_map[resource_name] then + for i = 1, 100 do + local resource_name_ = string.format("%s-%d", resource_name, i) + if not names_map[resource_name_] then + resource_name = resource_name_ + break + end + end + if 
names_map[resource_name] then + Utils.error(string.format("Failed to add resource, name conflict: %s", resource_name)) + return nil + end + end + end + local resp = curl.post(M.get_rag_service_url() .. "/api/v1/add_resource", { + headers = { + ["Content-Type"] = "application/json", + }, + body = vim.json.encode({ + name = resource_name, + uri = uri, + }), + }) + if resp.status ~= 200 then + Utils.error("failed to add resource: " .. resp.body) + return + end + return vim.json.decode(resp.body) +end + +function M.remove_resource(uri) + uri = M.to_container_uri(uri) + local resp = curl.post(M.get_rag_service_url() .. "/api/v1/remove_resource", { + headers = { + ["Content-Type"] = "application/json", + }, + body = vim.json.encode({ + uri = uri, + }), + }) + if resp.status ~= 200 then + Utils.error("failed to remove resource: " .. resp.body) + return + end + return vim.json.decode(resp.body) +end + +---@class AvanteRagServiceRetrieveSource +---@field uri string +---@field content string + +---@class AvanteRagServiceRetrieveResponse +---@field response string +---@field sources AvanteRagServiceRetrieveSource[] + +---@param base_uri string +---@param query string +---@return AvanteRagServiceRetrieveResponse | nil resp +---@return string | nil error +function M.retrieve(base_uri, query) + base_uri = M.to_container_uri(base_uri) + local resp = curl.post(M.get_rag_service_url() .. "/api/v1/retrieve", { + headers = { + ["Content-Type"] = "application/json", + }, + body = vim.json.encode({ + base_uri = base_uri, + query = query, + top_k = 10, + }), + }) + if resp.status ~= 200 then + Utils.error("failed to retrieve: " .. resp.body) + return nil, "failed to retrieve: " .. 
resp.body + end + local jsn = vim.json.decode(resp.body) + jsn.sources = vim + .iter(jsn.sources) + :map(function(source) + local uri = M.to_local_uri(source.uri) + return vim.tbl_deep_extend("force", source, { uri = uri }) + end) + :totable() + return jsn, nil +end + +---@class AvanteRagServiceIndexingStatusSummary +---@field indexing integer +---@field completed integer +---@field failed integer + +---@class AvanteRagServiceIndexingStatusResponse +---@field uri string +---@field is_watched boolean +---@field total_files integer +---@field status_summary AvanteRagServiceIndexingStatusSummary + +---@param uri string +---@return AvanteRagServiceIndexingStatusResponse | nil +function M.indexing_status(uri) + uri = M.to_container_uri(uri) + local resp = curl.post(M.get_rag_service_url() .. "/api/v1/indexing_status", { + headers = { + ["Content-Type"] = "application/json", + }, + body = vim.json.encode({ + uri = uri, + }), + }) + if resp.status ~= 200 then + Utils.error("Failed to get indexing status: " .. resp.body) + return + end + local jsn = vim.json.decode(resp.body) + jsn.uri = M.to_local_uri(jsn.uri) + return jsn +end + +---@class AvanteRagServiceResource +---@field name string +---@field uri string +---@field type string +---@field status string +---@field indexing_status string +---@field created_at string +---@field indexing_started_at string | nil +---@field last_indexed_at string | nil + +---@class AvanteRagServiceResourceListResponse +---@field resources AvanteRagServiceResource[] +---@field total_count number + +---@return AvanteRagServiceResourceListResponse | nil +M.get_resources = function() + local resp = curl.get(M.get_rag_service_url() .. "/api/v1/resources", { + headers = { + ["Content-Type"] = "application/json", + }, + }) + if resp.status ~= 200 then + Utils.error("Failed to get resources: " .. 
resp.body) + return + end + local jsn = vim.json.decode(resp.body) + jsn.resources = vim + .iter(jsn.resources) + :map(function(resource) + local uri = M.to_local_uri(resource.uri) + return vim.tbl_deep_extend("force", resource, { uri = uri }) + end) + :totable() + return jsn +end + +return M diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 153e74f..9696ac5 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -1227,7 +1227,7 @@ function Sidebar:apply(current_cursor) if last_orig_diff_end_line > #original_code_lines then pcall(function() api.nvim_win_set_cursor(winid, { #original_code_lines, 0 }) end) else - api.nvim_win_set_cursor(winid, { last_orig_diff_end_line, 0 }) + pcall(function() api.nvim_win_set_cursor(winid, { last_orig_diff_end_line, 0 }) end) end vim.cmd("normal! zz") end, @@ -2287,7 +2287,7 @@ function Sidebar:create_input_container(opts) local chat_history = Path.history.load(self.code.bufnr) - local tools = vim.deepcopy(LLMTools.tools) + local tools = vim.deepcopy(LLMTools.get_tools()) table.insert(tools, { name = "add_file_to_context", description = "Add a file to the context", diff --git a/lua/avante/templates/_tools-guidelines.avanterules b/lua/avante/templates/_tools-guidelines.avanterules index 58ff6e2..384ca2e 100644 --- a/lua/avante/templates/_tools-guidelines.avanterules +++ b/lua/avante/templates/_tools-guidelines.avanterules @@ -2,8 +2,10 @@ Don't directly search for code context in historical messages. Instead, prioriti Tools Usage Guide: - You have access to tools, but only use them when necessary. If a tool is not required, respond as normal. - - If you encounter a URL, prioritize using the fetch tool to obtain its content. - - If you have information that you don't know, please proactively use the tools provided by users! Especially the web search tool. + - If the `rag_search` tool exists, prioritize using it to do the search! 
+ - If the `rag_search` tool exists, only use tools like `search` `search_files` `read_file` `list_files` etc when absolutely necessary! + - If you encounter a URL, prioritize using the `fetch` tool to obtain its content. + - If you have information that you don't know, please proactively use the tools provided by users! Especially the `web_search` tool. - When available tools cannot meet the requirements, please try to use the `run_command` tool to solve the problem whenever possible. - When attempting to modify a file that is not in the context, please first use the `list_files` tool and `search_files` tool to check if the file you want to modify exists, then use the `read_file` tool to read the file content. Don't modify blindly! - When generating files, first use `list_files` tool to read the directory structure, don't generate blindly! diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 9599523..9034e63 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -329,6 +329,7 @@ vim.g.avante_login = vim.g.avante_login ---@field func? fun(input: any): (string | nil, string | nil) ---@field param AvanteLLMToolParam ---@field returns AvanteLLMToolReturn[] +---@field enabled? fun(): boolean ---@class AvanteLLMToolParam ---@field type string diff --git a/py/rag-service/Dockerfile b/py/rag-service/Dockerfile new file mode 100644 index 0000000..2dfe805 --- /dev/null +++ b/py/rag-service/Dockerfile @@ -0,0 +1,33 @@ +FROM debian:bookworm-slim + +WORKDIR /app + +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + && rm -rf /var/lib/apt/lists/* \ + && curl -LsSf https://astral.sh/uv/install.sh | sh + +ENV PATH="/root/.local/bin:$PATH" + +RUN uv python install 3.11 + +RUN uv python list + +ENV PATH="/root/.uv/python/3.11/bin:$PATH" + +COPY requirements.txt . + +RUN uv venv --python 3.11 + +RUN uv pip install -r requirements.txt + +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PORT=8000 + +EXPOSE ${PORT} + +COPY . . 
+ +CMD ["uv", "run", "fastapi", "run", "src/main.py", "--workers", "3"] diff --git a/py/rag-service/requirements.txt b/py/rag-service/requirements.txt new file mode 100644 index 0000000..bedd807 --- /dev/null +++ b/py/rag-service/requirements.txt @@ -0,0 +1,162 @@ +aiohappyeyeballs==2.4.6 +aiohttp==3.11.12 +aiosignal==1.3.2 +annotated-types==0.7.0 +anyio==4.8.0 +asgiref==3.8.1 +asttokens==3.0.0 +attrs==25.1.0 +backoff==2.2.1 +bcrypt==4.2.1 +beautifulsoup4==4.13.3 +build==1.2.2.post1 +cachetools==5.5.1 +certifi==2024.12.14 +charset-normalizer==3.4.1 +chroma-hnswlib==0.7.6 +chromadb==0.6.3 +click==8.1.8 +coloredlogs==15.0.1 +dataclasses-json==0.6.7 +decorator==5.1.1 +Deprecated==1.2.18 +dirtyjson==1.0.8 +distro==1.9.0 +dnspython==2.7.0 +durationpy==0.9 +email_validator==2.2.0 +executing==2.2.0 +fastapi==0.115.8 +fastapi-cli==0.0.7 +filelock==3.17.0 +filetype==1.2.0 +flatbuffers==25.1.24 +frozenlist==1.5.0 +fsspec==2025.2.0 +google-auth==2.38.0 +googleapis-common-protos==1.66.0 +greenlet==3.1.1 +grpcio==1.70.0 +h11==0.14.0 +httpcore==1.0.7 +httptools==0.6.4 +httpx==0.28.1 +huggingface-hub==0.28.1 +humanfriendly==10.0 +idna==3.10 +importlib_metadata==8.5.0 +importlib_resources==6.5.2 +ipython==8.32.0 +jedi==0.19.2 +Jinja2==3.1.5 +jiter==0.8.2 +joblib==1.4.2 +kubernetes==32.0.0 +llama-cloud==0.1.11 +llama-cloud-services==0.6.0 +llama-index==0.12.16 +llama-index-agent-openai==0.4.3 +llama-index-cli==0.4.0 +llama-index-core==0.12.16.post1 +llama-index-embeddings-openai==0.3.1 +llama-index-indices-managed-llama-cloud==0.6.4 +llama-index-llms-openai==0.3.18 +llama-index-multi-modal-llms-openai==0.4.3 +llama-index-program-openai==0.3.1 +llama-index-question-gen-openai==0.3.0 +llama-index-readers-file==0.4.4 +llama-index-readers-llama-parse==0.4.0 +llama-index-vector-stores-chroma==0.4.1 +llama-parse==0.6.0 +markdown-it-py==3.0.0 +markdownify==0.14.1 +MarkupSafe==3.0.2 +marshmallow==3.26.1 +matplotlib-inline==0.1.7 +mdurl==0.1.2 +mmh3==5.1.0 +monotonic==1.6 +mpmath==1.3.0 
+multidict==6.1.0 +mypy-extensions==1.0.0 +nest-asyncio==1.6.0 +networkx==3.4.2 +nltk==3.9.1 +numpy==2.2.2 +oauthlib==3.2.2 +onnxruntime==1.20.1 +openai==1.61.1 +opentelemetry-api==1.30.0 +opentelemetry-exporter-otlp-proto-common==1.30.0 +opentelemetry-exporter-otlp-proto-grpc==1.30.0 +opentelemetry-instrumentation==0.51b0 +opentelemetry-instrumentation-asgi==0.51b0 +opentelemetry-instrumentation-fastapi==0.51b0 +opentelemetry-proto==1.30.0 +opentelemetry-sdk==1.30.0 +opentelemetry-semantic-conventions==0.51b0 +opentelemetry-util-http==0.51b0 +orjson==3.10.15 +overrides==7.7.0 +packaging==24.2 +pandas==2.2.3 +parso==0.8.4 +pathspec==0.12.1 +pexpect==4.9.0 +pillow==11.1.0 +posthog==3.11.0 +prompt_toolkit==3.0.50 +propcache==0.2.1 +protobuf==5.29.3 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pyasn1==0.6.1 +pyasn1_modules==0.4.1 +pydantic==2.10.6 +pydantic_core==2.27.2 +Pygments==2.19.1 +pypdf==5.2.0 +PyPika==0.48.9 +pyproject_hooks==1.2.0 +python-dateutil==2.9.0.post0 +python-dotenv==1.0.1 +python-multipart==0.0.20 +pytz==2025.1 +PyYAML==6.0.2 +regex==2024.11.6 +requests==2.32.3 +requests-oauthlib==2.0.0 +rich==13.9.4 +rich-toolkit==0.13.2 +rsa==4.9 +shellingham==1.5.4 +six==1.17.0 +sniffio==1.3.1 +soupsieve==2.6 +SQLAlchemy==2.0.38 +stack-data==0.6.3 +starlette==0.45.3 +striprtf==0.0.26 +sympy==1.13.3 +tenacity==9.0.0 +tiktoken==0.8.0 +tokenizers==0.21.0 +tqdm==4.67.1 +traitlets==5.14.3 +tree-sitter==0.21.3 +tree-sitter-languages==1.10.2 +typer==0.15.1 +typing-inspect==0.9.0 +typing_extensions==4.12.2 +tzdata==2025.1 +urllib3==2.3.0 +uvicorn==0.34.0 +uvloop==0.21.0 +watchdog==6.0.0 +watchfiles==1.0.4 +wcwidth==0.2.13 +websocket-client==1.8.0 +websockets==14.2 +wrapt==1.17.2 +yarl==1.18.3 +zipp==3.21.0 diff --git a/py/rag-service/src/libs/__init__.py b/py/rag-service/src/libs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/py/rag-service/src/libs/configs.py b/py/rag-service/src/libs/configs.py new file mode 100644 index 0000000..bbc4cf8 --- /dev/null 
+++ b/py/rag-service/src/libs/configs.py @@ -0,0 +1,14 @@ +import os +from pathlib import Path + +# Configuration +BASE_DATA_DIR = Path(os.environ.get("DATA_DIR", "data")) +CHROMA_PERSIST_DIR = BASE_DATA_DIR / "chroma_db" +LOG_DIR = BASE_DATA_DIR / "logs" +DB_FILE = BASE_DATA_DIR / "sqlite" / "indexing_history.db" + +# Configure directories +BASE_DATA_DIR.mkdir(parents=True, exist_ok=True) +LOG_DIR.mkdir(parents=True, exist_ok=True) +DB_FILE.parent.mkdir(parents=True, exist_ok=True) # Create sqlite directory +CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True) diff --git a/py/rag-service/src/libs/db.py b/py/rag-service/src/libs/db.py new file mode 100644 index 0000000..c5afb57 --- /dev/null +++ b/py/rag-service/src/libs/db.py @@ -0,0 +1,60 @@ +import sqlite3 +from collections.abc import Generator +from contextlib import contextmanager + +from libs.configs import DB_FILE + +# SQLite table schemas +CREATE_TABLES_SQL = """ +CREATE TABLE IF NOT EXISTS indexing_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + uri TEXT NOT NULL, + content_hash TEXT NOT NULL, + status TEXT NOT NULL, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + error_message TEXT, + document_id TEXT, + metadata TEXT +); + +CREATE INDEX IF NOT EXISTS idx_uri ON indexing_history(uri); +CREATE INDEX IF NOT EXISTS idx_document_id ON indexing_history(document_id); +CREATE INDEX IF NOT EXISTS idx_content_hash ON indexing_history(content_hash); + +CREATE TABLE IF NOT EXISTS resources ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + uri TEXT NOT NULL UNIQUE, + type TEXT NOT NULL, -- 'path' or 'https' + status TEXT NOT NULL DEFAULT 'active', -- 'active' or 'inactive' + indexing_status TEXT NOT NULL DEFAULT 'pending', -- 'pending', 'indexing', 'indexed', 'failed' + indexing_status_message TEXT, + indexing_started_at DATETIME, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + last_indexed_at DATETIME, + last_error TEXT +); + +CREATE INDEX IF NOT EXISTS idx_resources_name ON 
resources(name); +CREATE INDEX IF NOT EXISTS idx_resources_uri ON resources(uri); +CREATE INDEX IF NOT EXISTS idx_resources_status ON resources(status); +CREATE INDEX IF NOT EXISTS idx_status ON indexing_history(status); +""" + + +@contextmanager +def get_db_connection() -> Generator[sqlite3.Connection, None, None]: + """Get a database connection.""" + conn = sqlite3.connect(DB_FILE) + conn.row_factory = sqlite3.Row + try: + yield conn + finally: + conn.close() + + +def init_db() -> None: + """Initialize the SQLite database.""" + with get_db_connection() as conn: + conn.executescript(CREATE_TABLES_SQL) + conn.commit() diff --git a/py/rag-service/src/libs/logger.py b/py/rag-service/src/libs/logger.py new file mode 100644 index 0000000..2c81de6 --- /dev/null +++ b/py/rag-service/src/libs/logger.py @@ -0,0 +1,16 @@ +import logging +from datetime import datetime + +from libs.configs import LOG_DIR + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler( + LOG_DIR / f"rag_service_{datetime.now().astimezone().strftime('%Y%m%d')}.log", + ), + logging.StreamHandler(), + ], +) +logger = logging.getLogger(__name__) diff --git a/py/rag-service/src/libs/utils.py b/py/rag-service/src/libs/utils.py new file mode 100644 index 0000000..c8cf9aa --- /dev/null +++ b/py/rag-service/src/libs/utils.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import re +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from llama_index.core.schema import BaseNode + +PATTERN_URI_PART = re.compile(r"(?P.+)__part_\d+") +METADATA_KEY_URI = "uri" + + +def uri_to_path(uri: str) -> Path: + """Convert URI to path.""" + return Path(uri.replace("file://", "")) + + +def path_to_uri(file_path: Path) -> str: + """Convert path to URI.""" + uri = file_path.as_uri() + if file_path.is_dir(): + uri += "/" + return uri + + +def is_local_uri(uri: str) -> bool: + """Check if the URI is a path URI.""" 
+ return uri.startswith("file://") + + +def is_remote_uri(uri: str) -> bool: + """Check if the URI is an HTTPS URI or HTTP URI.""" + return uri.startswith(("https://", "http://")) + + +def is_path_node(node: BaseNode) -> bool: + """Check if the node is a file node.""" + uri = get_node_uri(node) + if not uri: + return False + return is_local_uri(uri) + + +def get_node_uri(node: BaseNode) -> str | None: + """Get URI from node metadata.""" + uri = node.metadata.get(METADATA_KEY_URI) + if not uri: + doc_id = getattr(node, "doc_id", None) + if doc_id: + match = PATTERN_URI_PART.match(doc_id) + uri = match.group("uri") if match else doc_id + if uri: + if uri.startswith("/"): + uri = f"file://{uri}" + return uri + return None + + +def inject_uri_to_node(node: BaseNode) -> None: + """Inject file path into node metadata.""" + if METADATA_KEY_URI in node.metadata: + return + uri = get_node_uri(node) + if uri: + node.metadata[METADATA_KEY_URI] = uri diff --git a/py/rag-service/src/main.py b/py/rag-service/src/main.py new file mode 100644 index 0000000..de2d814 --- /dev/null +++ b/py/rag-service/src/main.py @@ -0,0 +1,1114 @@ +"""RAG Service API for managing document indexing and retrieval.""" # noqa: INP001 + +from __future__ import annotations + +import asyncio +import fcntl +import json +import multiprocessing +import os +import re +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from contextlib import asynccontextmanager +from pathlib import Path +from typing import TYPE_CHECKING +from urllib.parse import urljoin, urlparse + +import chromadb +import httpx +import pathspec +from fastapi import BackgroundTasks, FastAPI, HTTPException +from libs.configs import ( + BASE_DATA_DIR, + CHROMA_PERSIST_DIR, +) +from libs.db import init_db +from libs.logger import logger +from libs.utils import ( + get_node_uri, + inject_uri_to_node, + is_local_uri, + is_path_node, + is_remote_uri, + path_to_uri, + uri_to_path, +) +from llama_index.core import ( + 
Settings, + SimpleDirectoryReader, + StorageContext, + VectorStoreIndex, + load_index_from_storage, +) +from llama_index.core.node_parser import CodeSplitter +from llama_index.core.schema import Document +from llama_index.embeddings.openai import OpenAIEmbedding +from llama_index.vector_stores.chroma import ChromaVectorStore +from markdownify import markdownify as md +from models.indexing_history import IndexingHistory # noqa: TC002 +from models.resource import Resource +from pydantic import BaseModel, Field +from services.indexing_history import indexing_history_service +from services.resource import resource_service +from watchdog.events import FileSystemEvent, FileSystemEventHandler +from watchdog.observers import Observer + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + from llama_index.core.schema import NodeWithScore, QueryBundle + from watchdog.observers.api import BaseObserver + +# Lock file for leader election +LOCK_FILE = BASE_DATA_DIR / "leader.lock" + + +def try_acquire_leadership() -> bool: + """Try to acquire leadership using file lock.""" + try: + # Ensure the lock file exists + LOCK_FILE.parent.mkdir(parents=True, exist_ok=True) + LOCK_FILE.touch(exist_ok=True) + + # Try to acquire an exclusive lock + lock_fd = os.open(str(LOCK_FILE), os.O_RDWR) + fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + + # Write current process ID to lock file + os.truncate(lock_fd, 0) + os.write(lock_fd, str(os.getpid()).encode()) + + return True + except OSError: + return False + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 + """Initialize services on startup.""" + # Try to become leader if no worker_id is set + + is_leader = try_acquire_leadership() + + # Only run initialization in the leader + if is_leader: + logger.info("Starting RAG service as leader (PID: %d)...", os.getpid()) + + # Get all active resources + active_resources = [r for r in resource_service.get_all_resources() if 
r.status == "active"] + logger.info("Found %d active resources to sync", len(active_resources)) + + for resource in active_resources: + try: + if is_local_uri(resource.uri): + directory = uri_to_path(resource.uri) + if not directory.exists(): + logger.error("Directory not found: %s", directory) + resource_service.update_resource_status(resource.uri, "error", "Directory not found") + continue + + # Start file system watcher + event_handler = FileSystemHandler(directory=directory) + observer = Observer() + observer.schedule(event_handler, str(directory), recursive=True) + observer.start() + watched_resources[resource.uri] = observer + + # Start indexing + await index_local_resource_async(resource) + + elif is_remote_uri(resource.uri): + if not is_remote_resource_exists(resource.uri): + logger.error("HTTPS resource not found: %s", resource.uri) + resource_service.update_resource_status(resource.uri, "error", "remote resource not found") + continue + + # Start indexing + await index_remote_resource_async(resource) + + logger.info("Successfully synced resource: %s", resource.uri) + + except (OSError, ValueError, RuntimeError) as e: + error_msg = f"Failed to sync resource {resource.uri}: {e}" + logger.exception(error_msg) + resource_service.update_resource_status(resource.uri, "error", error_msg) + + yield + + # Cleanup on shutdown (only in leader) + if is_leader: + for observer in watched_resources.values(): + observer.stop() + observer.join() + + +app = FastAPI( + title="RAG Service API", + description=""" + RAG (Retrieval-Augmented Generation) Service API for managing document indexing and retrieval. 
+ + ## Features + * Add resources for document watching and indexing + * Remove watched resources + * Retrieve relevant information from indexed resources + * Monitor indexing status + """, + version="1.0.0", + docs_url="/docs", + lifespan=lifespan, + redoc_url="/redoc", +) + +# Constants +SIMILARITY_THRESHOLD = 0.95 +MAX_SAMPLE_SIZE = 100 +BATCH_PROCESSING_DELAY = 1 + +# number of cpu cores to use for parallel processing +MAX_WORKERS = multiprocessing.cpu_count() +BATCH_SIZE = 40 # Number of documents to process per batch + +logger.info("data dir: %s", BASE_DATA_DIR.resolve()) + +# Global variables +watched_resources: dict[str, BaseObserver] = {} # Directory path -> Observer instance mapping +file_last_modified: dict[Path, float] = {} # File path -> Last modified time mapping +index_lock = threading.Lock() + +code_ext_map = { + ".py": "python", + ".js": "javascript", + ".ts": "typescript", + ".jsx": "javascript", + ".tsx": "typescript", + ".vue": "vue", + ".go": "go", + ".java": "java", + ".cpp": "cpp", + ".c": "c", + ".h": "cpp", + ".rs": "rust", + ".rb": "ruby", + ".php": "php", + ".scala": "scala", + ".kt": "kotlin", + ".swift": "swift", + ".lua": "lua", + ".pl": "perl", + ".pm": "perl", + ".t": "perl", + ".pm6": "perl", + ".m": "perl", +} + +required_exts = [ + ".txt", + ".pdf", + ".docx", + ".xlsx", + ".pptx", + ".rst", + ".json", + ".ini", + ".conf", + ".toml", + ".md", + ".markdown", + ".csv", + ".tsv", + ".html", + ".htm", + ".xml", + ".yaml", + ".yml", + ".css", + ".scss", + ".less", + ".sass", + ".styl", + ".sh", + ".bash", + ".zsh", + ".fish", + ".rb", + ".java", + ".go", + ".ts", + ".tsx", + ".js", + ".jsx", + ".vue", + ".py", + ".php", + ".c", + ".cpp", + ".h", + ".rs", + ".swift", + ".kt", + ".lua", + ".perl", + ".pl", + ".pm", + ".t", + ".pm6", + ".m", +] + + +http_headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36", +} + + +def is_remote_resource_exists(url: 
str) -> bool: + """Check if a URL exists.""" + try: + response = httpx.head(url, headers=http_headers) + return response.status_code in {httpx.codes.OK, httpx.codes.MOVED_PERMANENTLY, httpx.codes.FOUND} + except (OSError, ValueError, RuntimeError) as e: + logger.error("Error checking if URL exists %s: %s", url, e) + return False + + +def fetch_markdown(url: str) -> str: + """Fetch markdown content from a URL.""" + try: + logger.info("Fetching markdown content from %s", url) + response = httpx.get(url, headers=http_headers) + if response.status_code == httpx.codes.OK: + return md(response.text) + return "" + except (OSError, ValueError, RuntimeError) as e: + logger.error("Error fetching markdown content %s: %s", url, e) + return "" + + +def markdown_to_links(base_url: str, markdown: str) -> list[str]: + """Extract links from markdown content.""" + links = [] + seek = {base_url} + parsed_url = urlparse(base_url) + domain = parsed_url.netloc + scheme = parsed_url.scheme + for match in re.finditer(r"\[(.*?)\]\((.*?)\)", markdown): + url = match.group(2) + if not url.startswith(scheme): + url = urljoin(base_url, url) + if urlparse(url).netloc != domain: + continue + if url in seek: + continue + seek.add(url) + links.append(url) + return links + + +# Initialize database +init_db() + +# Initialize ChromaDB and LlamaIndex services +chroma_client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR)) +chroma_collection = chroma_client.get_or_create_collection("documents") +vector_store = ChromaVectorStore(chroma_collection=chroma_collection) +storage_context = StorageContext.from_defaults(vector_store=vector_store) +embed_model = OpenAIEmbedding() +Settings.embed_model = embed_model + + +try: + index = load_index_from_storage(storage_context) +except (OSError, ValueError) as e: + logger.error("Failed to load index from storage: %s", e) + index = VectorStoreIndex([], storage_context=storage_context) + + +class ResourceRequest(BaseModel): + """Request model for resource 
operations.""" + + name: str = Field(..., description="Name of the resource to watch and index") + uri: str = Field(..., description="URI of the resource to watch and index") + + +class SourceDocument(BaseModel): + """Model for source document information.""" + + uri: str = Field(..., description="URI of the source") + content: str = Field(..., description="Content snippet from the document") + score: float | None = Field(None, description="Relevance score of the document") + + +class RetrieveRequest(BaseModel): + """Request model for information retrieval.""" + + query: str = Field(..., description="The query text to search for in the indexed documents") + base_uri: str = Field(..., description="The base URI to search in") + top_k: int | None = Field(5, description="Number of top results to return", ge=1, le=20) + + +class RetrieveResponse(BaseModel): + """Response model for information retrieval.""" + + response: str = Field(..., description="Generated response to the query") + sources: list[SourceDocument] = Field(..., description="List of source documents used") + + +class FileSystemHandler(FileSystemEventHandler): + """Handler for file system events.""" + + def __init__(self: FileSystemHandler, directory: Path) -> None: + """Initialize the handler.""" + self.directory = directory + + def on_modified(self: FileSystemHandler, event: FileSystemEvent) -> None: + """Handle file modification events.""" + if not event.is_directory and not str(event.src_path).endswith(".tmp"): + self.handle_file_change(Path(str(event.src_path))) + + def on_created(self: FileSystemHandler, event: FileSystemEvent) -> None: + """Handle file creation events.""" + if not event.is_directory and not str(event.src_path).endswith(".tmp"): + self.handle_file_change(Path(str(event.src_path))) + + def handle_file_change(self: FileSystemHandler, file_path: Path) -> None: + """Handle changes to a file.""" + current_time = time.time() + + abs_file_path = file_path + if not 
Path(abs_file_path).is_absolute(): + abs_file_path = Path(self.directory, file_path) + + # Check if the file was recently processed + if abs_file_path in file_last_modified and current_time - file_last_modified[abs_file_path] < BATCH_PROCESSING_DELAY: + return + + file_last_modified[abs_file_path] = current_time + threading.Thread(target=update_index_for_file, args=(self.directory, abs_file_path)).start() + + +def is_valid_text(text: str) -> bool: + """Check if the text is valid and readable.""" + if not text: + logger.debug("Text content is empty") + return False + + # Check if the text mainly contains printable characters + printable_ratio = sum(1 for c in text if c.isprintable() or c in "\n\r\t") / len(text) + if printable_ratio <= SIMILARITY_THRESHOLD: + logger.debug("Printable character ratio too low: %.2f%%", printable_ratio * 100) + # Output a small sample for analysis + sample = text[:MAX_SAMPLE_SIZE] if len(text) > MAX_SAMPLE_SIZE else text + logger.debug("Text sample: %r", sample) + return printable_ratio > SIMILARITY_THRESHOLD + + +def clean_text(text: str) -> str: + """Clean text content by removing non-printable characters.""" + return "".join(char for char in text if char.isprintable() or char in "\n\r\t") + + +def process_document_batch(documents: list[Document]) -> bool: # noqa: PLR0915, C901, PLR0912, RUF100 + """Process a batch of documents for embedding.""" + try: + # Filter out invalid and already processed documents + valid_documents = [] + invalid_documents = [] + for doc in documents: + doc_id = doc.doc_id + + # Check if document with same hash has already been successfully processed + status_records = indexing_history_service.get_indexing_status(doc=doc) + if status_records and status_records[0].status == "completed": + logger.info("Document with same hash already processed, skipping: %s", doc.doc_id) + continue + + logger.info("Processing document: %s", doc.doc_id) + try: + content = doc.get_content() + + # If content is bytes type, try to 
decode + if isinstance(content, bytes): + try: + content = content.decode("utf-8", errors="replace") + except (UnicodeDecodeError, OSError) as e: + error_msg = f"Unable to decode document content: {doc_id}, error: {e!s}" + logger.warning(error_msg) + indexing_history_service.update_indexing_status(doc, "failed", error_message=error_msg) + invalid_documents.append(doc_id) + continue + + # Ensure content is string type + content = str(content) + + if not is_valid_text(content): + error_msg = f"Invalid document content: {doc_id}" + logger.warning(error_msg) + indexing_history_service.update_indexing_status(doc, "failed", error_message=error_msg) + invalid_documents.append(doc_id) + continue + + # Create new document object with cleaned content + from llama_index.core.schema import Document + + cleaned_content = clean_text(content) + metadata = getattr(doc, "metadata", {}).copy() + + new_doc = Document( + text=cleaned_content, + doc_id=doc_id, + metadata=metadata, + ) + inject_uri_to_node(new_doc) + valid_documents.append(new_doc) + # Update status to indexing for valid documents + indexing_history_service.update_indexing_status(doc, "indexing") + + except OSError as e: + error_msg = f"Document processing failed: {doc_id}, error: {e!s}" + logger.exception(error_msg) + indexing_history_service.update_indexing_status(doc, "failed", error_message=error_msg) + invalid_documents.append(doc_id) + + try: + if valid_documents: + with index_lock: + index.refresh_ref_docs(valid_documents) + + # Update status to completed for successfully processed documents + for doc in valid_documents: + indexing_history_service.update_indexing_status( + doc, + "completed", + metadata=doc.metadata, + ) + + return not invalid_documents + + except OSError as e: + error_msg = f"Batch indexing failed: {e!s}" + logger.exception(error_msg) + # Update status to failed for all documents in the batch + for doc in valid_documents: + indexing_history_service.update_indexing_status(doc, "failed", 
error_message=error_msg) + return False + + except OSError as e: + error_msg = f"Batch processing failed: {e!s}" + logger.exception(error_msg) + # Update status to failed for all documents in the batch + for doc in documents: + indexing_history_service.update_indexing_status(doc, "failed", error_message=error_msg) + return False + + +def get_pathspec(directory: Path) -> pathspec.PathSpec | None: + """Get pathspec for the directory.""" + gitignore_path = directory / ".gitignore" + if not gitignore_path.exists(): + return None + + # Read gitignore patterns + with gitignore_path.open("r", encoding="utf-8") as f: + return pathspec.GitIgnoreSpec.from_lines([*f.readlines(), ".git/"]) + + +def scan_directory(directory: Path) -> list[str]: + """Scan directory and return a list of matched files.""" + spec = get_pathspec(directory) + + matched_files = [] + + for root, _, files in os.walk(directory): + file_paths = [str(Path(root) / file) for file in files] + if not spec: + matched_files.extend(file_paths) + continue + matched_files.extend([file for file in file_paths if not spec.match_file(file)]) + + return matched_files + + +def update_index_for_file(directory: Path, abs_file_path: Path) -> None: + """Update the index for a single file.""" + logger.info("Starting to index file: %s", abs_file_path) + + rel_file_path = abs_file_path.relative_to(directory) + + spec = get_pathspec(directory) + if spec and spec.match_file(rel_file_path): + logger.info("File is ignored, skipping: %s", abs_file_path) + return + + resource = resource_service.get_resource(path_to_uri(directory)) + if not resource: + logger.error("Resource not found for directory: %s", directory) + return + + resource_service.update_resource_indexing_status(resource.uri, "indexing", "") + + documents = SimpleDirectoryReader( + input_files=[abs_file_path], + filename_as_id=True, + required_exts=required_exts, + ).load_data() + + logger.info("Updating index: %s", abs_file_path) + processed_documents = 
split_documents(documents) + success = process_document_batch(processed_documents) + + if success: + resource_service.update_resource_indexing_status(resource.uri, "indexed", "") + logger.info("File indexing completed: %s", abs_file_path) + else: + resource_service.update_resource_indexing_status(resource.uri, "failed", "unknown error") + logger.error("File indexing failed: %s", abs_file_path) + + +def split_documents(documents: list[Document]) -> list[Document]: + """Split documents into code and non-code documents.""" + # Create file parser configuration + # Initialize CodeSplitter + code_splitter = CodeSplitter( + language="python", # Default is python, will auto-detect based on file extension + chunk_lines=80, # Maximum number of lines per code block + chunk_lines_overlap=15, # Number of overlapping lines to maintain context + max_chars=1500, # Maximum number of characters per block + ) + # Split code documents using CodeSplitter + processed_documents = [] + for doc in documents: + uri = get_node_uri(doc) + if not uri: + continue + if not is_path_node(doc): + processed_documents.append(doc) + continue + file_path = uri_to_path(uri) + file_ext = file_path.suffix.lower() + if file_ext in code_ext_map: + # Apply CodeSplitter to code files + code_splitter.language = code_ext_map.get(file_ext, "python") + + try: + texts = code_splitter.split_text(doc.get_content()) + except ValueError as e: + logger.error("Error splitting document: %s, so skipping split, error: %s", doc.doc_id, str(e)) + processed_documents.append(doc) + continue + + for i, text in enumerate(texts): + from llama_index.core.schema import Document + + new_doc = Document( + text=text, + doc_id=f"{doc.doc_id}__part_{i}", + metadata={ + **doc.metadata, + "chunk_number": i, + "total_chunks": len(texts), + "language": code_splitter.language, + "orig_doc_id": doc.doc_id, + }, + ) + processed_documents.append(new_doc) + else: + doc.metadata["orig_doc_id"] = doc.doc_id + # Add non-code files directly + 
processed_documents.append(doc) + return processed_documents + + +async def index_remote_resource_async(resource: Resource) -> None: + """Asynchronously index a remote resource.""" + resource_service.update_resource_indexing_status(resource.uri, "indexing", "") + url = resource.uri + try: + logger.info("Loading resource content: %s", url) + + # Fetch markdown content + markdown = fetch_markdown(url) + + link_md_pairs = [(url, markdown)] + + # Extract links from markdown + links = markdown_to_links(url, markdown) + + logger.info("Found %d sub links", len(links)) + logger.debug("Link list: %s", links) + + # Use thread pool for parallel batch processing + loop = asyncio.get_event_loop() + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + mds: list[str] = await loop.run_in_executor( + executor, + lambda: list(executor.map(fetch_markdown, links)), + ) + + zipped = zip(links, mds, strict=True) # pyright: ignore + link_md_pairs.extend(zipped) + + # Create documents from links + documents = [Document(text=markdown, doc_id=link) for link, markdown in link_md_pairs] + + logger.info("Found %d documents", len(documents)) + logger.debug("Document list: %s", [doc.doc_id for doc in documents]) + + # Process documents in batches + total_documents = len(documents) + batches = [documents[i : i + BATCH_SIZE] for i in range(0, total_documents, BATCH_SIZE)] + logger.info("Splitting documents into %d batches for processing", len(batches)) + + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + results = await loop.run_in_executor( + executor, + lambda: list(executor.map(process_document_batch, batches)), + ) + + # Check processing results + if all(results): + logger.info("Resource %s indexing completed", url) + resource_service.update_resource_indexing_status(resource.uri, "indexed", "") + else: + failed_batches = len([r for r in results if not r]) + error_msg = f"Some batches failed processing ({failed_batches}/{len(batches)})" + logger.error(error_msg) + 
resource_service.update_resource_indexing_status(resource.uri, "indexed", error_msg) + + except OSError as e: + error_msg = f"Resource indexing failed: {url}" + logger.exception(error_msg) + resource_service.update_resource_indexing_status(resource.uri, "failed", error_msg) + raise e # noqa: TRY201 + + +async def index_local_resource_async(resource: Resource) -> None: + """Asynchronously index a directory.""" + resource_service.update_resource_indexing_status(resource.uri, "indexing", "") + directory_path = uri_to_path(resource.uri) + try: + logger.info("Loading directory content: %s", directory_path) + + from llama_index.core.readers.file.base import SimpleDirectoryReader + + documents = SimpleDirectoryReader( + input_files=scan_directory(directory_path), + filename_as_id=True, + required_exts=required_exts, + ).load_data() + + processed_documents = split_documents(documents) + + logger.info("Found %d documents", len(processed_documents)) + logger.debug("Document list: %s", [doc.doc_id for doc in processed_documents]) + + # Process documents in batches + total_documents = len(processed_documents) + batches = [processed_documents[i : i + BATCH_SIZE] for i in range(0, total_documents, BATCH_SIZE)] + logger.info("Splitting documents into %d batches for processing", len(batches)) + + # Use thread pool for parallel batch processing + loop = asyncio.get_event_loop() + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + results = await loop.run_in_executor( + executor, + lambda: list(executor.map(process_document_batch, batches)), + ) + + # Check processing results + if all(results): + logger.info("Directory %s indexing completed", directory_path) + resource_service.update_resource_indexing_status(resource.uri, "indexed", "") + else: + failed_batches = len([r for r in results if not r]) + error_msg = f"Some batches failed processing ({failed_batches}/{len(batches)})" + resource_service.update_resource_indexing_status(resource.uri, "indexed", error_msg) + 
logger.error(error_msg) + + except OSError as e: + error_msg = f"Directory indexing failed: {directory_path}" + resource_service.update_resource_indexing_status(resource.uri, "failed", error_msg) + logger.exception(error_msg) + raise e # noqa: TRY201 + + +@app.get("/api/v1/readyz") +async def readiness_probe() -> dict[str, str]: + """Readiness probe endpoint.""" + return {"status": "ok"} + + +@app.post( + "/api/v1/add_resource", + response_model="dict[str, str]", + summary="Add a resource for watching and indexing", + description=""" + Adds a resource to the watch list and starts indexing all existing documents in it asynchronously. + """, + responses={ + 200: {"description": "Resource successfully added and indexing started"}, + 404: {"description": "Resource not found"}, + 400: {"description": "Resource already being watched"}, + }, +) +async def add_resource(request: ResourceRequest, background_tasks: BackgroundTasks): # noqa: D103, ANN201, C901 + # Check if resource already exists + resource = resource_service.get_resource(request.uri) + if resource and resource.status == "active": + return { + "status": "success", + "message": f"Resource {request.uri} added and indexing started in background", + } + + resource_type = "local" + + async def background_task(resource: Resource) -> None: + pass + + if is_local_uri(request.uri): + directory = uri_to_path(request.uri) + if not directory.exists(): + raise HTTPException(status_code=404, detail="Directory not found") + + if not directory.is_dir(): + raise HTTPException(status_code=400, detail="Not a directory") + + # Create observer + event_handler = FileSystemHandler(directory=directory) + observer = Observer() + observer.schedule(event_handler, str(directory), recursive=True) + observer.start() + watched_resources[request.uri] = observer + + background_task = index_local_resource_async + elif is_remote_uri(request.uri): + if not is_remote_resource_exists(request.uri): + raise HTTPException(status_code=404, detail="web 
resource not found") + + resource_type = "remote" + + background_task = index_remote_resource_async + else: + raise HTTPException(status_code=400, detail=f"Invalid URI: {request.uri}") + + if resource: + if resource.name != request.name: + raise HTTPException(status_code=400, detail=f"Resource name cannot be changed: {resource.name}") + + resource_service.update_resource_status(resource.uri, "active") + else: + exists_resource = resource_service.get_resource_by_name(request.name) + if exists_resource: + raise HTTPException(status_code=400, detail="Resource with same name already exists") + # Add to database + resource = Resource( + id=None, + name=request.name, + uri=request.uri, + type=resource_type, + status="active", + indexing_status="pending", + indexing_status_message=None, + indexing_started_at=None, + last_indexed_at=None, + last_error=None, + ) + resource_service.add_resource_to_db(resource) + background_tasks.add_task(background_task, resource) + + return { + "status": "success", + "message": f"Resource {request.uri} added and indexing started in background", + } + + +@app.post( + "/api/v1/remove_resource", + response_model="dict[str, str]", + summary="Remove a watched resource", + description="Stops watching and indexing the specified resource", + responses={ + 200: {"description": "Resource successfully removed from watch list"}, + 404: {"description": "Resource not found in watch list"}, + }, +) +async def remove_resource(request: ResourceRequest): # noqa: D103, ANN201 + resource = resource_service.get_resource(request.uri) + if not resource or resource.status != "active": + raise HTTPException(status_code=404, detail="Resource not being watched") + + if request.uri in watched_resources: + # Stop watching + observer = watched_resources[request.uri] + observer.stop() + observer.join() + del watched_resources[request.uri] + + # Update database status + resource_service.update_resource_status(request.uri, "inactive") + + return {"status": "success", 
"message": f"Resource {request.uri} removed"} + + +@app.post( + "/api/v1/retrieve", + response_model=RetrieveResponse, + summary="Retrieve information from indexed documents", + description=""" + Performs a semantic search over all indexed documents and returns relevant information. + The response includes both the answer and the source documents used to generate it. + """, + responses={ + 200: {"description": "Successfully retrieved information"}, + 500: {"description": "Internal server error during retrieval"}, + }, +) +async def retrieve(request: RetrieveRequest): # noqa: D103, ANN201, C901, PLR0915 + if is_local_uri(request.base_uri): + directory = uri_to_path(request.base_uri) + # Validate directory exists + if not directory.exists(): + raise HTTPException(status_code=404, detail=f"Directory not found: {request.base_uri}") + + logger.info( + "Received retrieval request: %s for base uri: %s", + request.query, + request.base_uri, + ) + + cached_file_contents = {} + + # Create a filter function to only include documents from the specified directory + def filter_documents(node: NodeWithScore) -> bool: + uri = get_node_uri(node.node) + if not uri: + return False + if is_path_node(node.node): + file_path = uri_to_path(uri) + # Check if the file path starts with the specified directory + file_path = file_path.resolve() + directory = uri_to_path(request.base_uri).resolve() + # Check if directory is a parent of file_path + try: + file_path.relative_to(directory) + if not file_path.exists(): + logger.warning("File not found: %s", file_path) + return False + content = cached_file_contents.get(file_path) + if content is None: + with file_path.open("r", encoding="utf-8") as f: + content = f.read() + cached_file_contents[file_path] = content + if node.node.get_content() not in content: + logger.warning("File content does not match: %s", file_path) + return False + return True + except ValueError: + return False + if uri == request.base_uri: + return True + base_uri = 
request.base_uri + if not base_uri.endswith(os.path.sep): + base_uri += os.path.sep + return uri.startswith(base_uri) + + from llama_index.core.postprocessor import MetadataReplacementPostProcessor + + # Create a custom post processor + class ResourceFilterPostProcessor(MetadataReplacementPostProcessor): + """Post-processor for filtering nodes based on directory.""" + + def __init__(self: ResourceFilterPostProcessor) -> None: + """Initialize the post-processor.""" + super().__init__(target_metadata_key="filtered") + + def postprocess_nodes( + self: ResourceFilterPostProcessor, + nodes: list[NodeWithScore], + query_bundle: QueryBundle | None = None, # noqa: ARG002, pyright: ignore + query_str: str | None = None, # noqa: ARG002, pyright: ignore + ) -> list[NodeWithScore]: + """ + Filter nodes based on directory path. + + Args: + ---- + nodes: The nodes to process + query_bundle: Optional query bundle for the query + query_str: Optional query string + + Returns: + ------- + List of filtered nodes + + """ + return [node for node in nodes if filter_documents(node)] + + # Create query engine with the filter + query_engine = index.as_query_engine( + node_postprocessors=[ResourceFilterPostProcessor()], + ) + + logger.info("Executing retrieval query") + response = query_engine.query(request.query) + + # If no documents were found in the specified directory + if not response.source_nodes: + raise HTTPException( + status_code=404, + detail=f"No relevant documents found in uri: {request.base_uri}", + ) + + # Process source documents, ensure readable text + sources = [] + for node in response.source_nodes[: request.top_k]: + try: + content = node.node.get_content() + + uri = get_node_uri(node.node) + + # Handle byte-type content + if isinstance(content, bytes): + try: + content = content.decode("utf-8", errors="replace") + except UnicodeDecodeError as e: + logger.warning( + "Unable to decode document content: %s, error: %s", + uri, + str(e), + ) + continue + + # Validate and 
clean text + if is_valid_text(str(content)): + cleaned_content = clean_text(str(content)) + # Add document source information with file path + doc_info = { + "uri": uri, + "content": cleaned_content, + "score": float(node.score) if hasattr(node, "score") else None, + } + sources.append(doc_info) + else: + logger.warning("Skipping invalid document content: %s", uri) + + except (OSError, UnicodeDecodeError, json.JSONDecodeError): + logger.warning("Error processing source document", exc_info=True) + continue + + logger.info("Retrieval completed, found %d relevant documents", len(sources)) + + # Process response text similarly + response_text = str(response) + response_text = "".join(char for char in response_text if char.isprintable() or char in "\n\r\t") + + return { + "response": response_text, + "sources": sources, + } + + +class IndexingStatusRequest(BaseModel): + """Request model for indexing status.""" + + uri: str = Field(..., description="URI of the resource to get indexing status for") + + +class IndexingStatusResponse(BaseModel): + """Model for indexing status response.""" + + uri: str = Field(..., description="URI of the resource being monitored") + is_watched: bool = Field(..., description="Whether the directory is currently being watched") + files: list[IndexingHistory] = Field(..., description="List of files and their indexing status") + total_files: int = Field(..., description="Total number of files processed in this directory") + status_summary: dict[str, int] = Field( + ..., + description="Summary of indexing statuses (count by status)", + ) + + +@app.post( + "/api/v1/indexing-status", + response_model=IndexingStatusResponse, + summary="Get indexing status for a resource", + description=""" + Returns the current indexing status for all files in the specified resource, including: + * Whether the resource is being watched + * Status of each files in the resource + """, + responses={ + 200: {"description": "Successfully retrieved indexing status"}, + 
404: {"description": "Resource not found"}, + }, +) +async def get_indexing_status_for_resource(request: IndexingStatusRequest): # noqa: D103, ANN201 + resource_files = [] + status_counts = {} + if is_local_uri(request.uri): + directory = uri_to_path(request.uri).resolve() + if not directory.exists(): + raise HTTPException(status_code=404, detail=f"Directory not found: {directory}") + + # Get indexing history records for the specific directory + resource_files = indexing_history_service.get_indexing_status(base_uri=request.uri) + + logger.info("Found %d files in resource %s", len(resource_files), request.uri) + for file in resource_files: + logger.debug("File status: %s - %s", file.uri, file.status) + + # Count files by status + for file in resource_files: + status_counts[file.status] = status_counts.get(file.status, 0) + 1 + + return IndexingStatusResponse( + uri=request.uri, + is_watched=request.uri in watched_resources, + files=resource_files, + total_files=len(resource_files), + status_summary=status_counts, + ) + + +class ResourceListResponse(BaseModel): + """Response model for listing resources.""" + + resources: list[Resource] = Field(..., description="List of all resources") + total_count: int = Field(..., description="Total number of resources") + status_summary: dict[str, int] = Field( + ..., + description="Summary of resource statuses (count by status)", + ) + + +@app.get( + "/api/v1/resources", + response_model=ResourceListResponse, + summary="List all resources", + description=""" + Returns a list of all resources that have been added to the system, including: + * Resource URI + * Resource type (path/https) + * Current status + * Last indexed timestamp + * Any errors + """, + responses={ + 200: {"description": "Successfully retrieved resource list"}, + }, +) +async def list_resources() -> ResourceListResponse: + """Get all resources and their current status.""" + # Get all resources from database + resources = resource_service.get_all_resources() + + # 
Count resources by status + status_counts = {} + for resource in resources: + status_counts[resource.status] = status_counts.get(resource.status, 0) + 1 + + return ResourceListResponse( + resources=resources, + total_count=len(resources), + status_summary=status_counts, + ) diff --git a/py/rag-service/src/models/__init__.py b/py/rag-service/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/py/rag-service/src/models/indexing_history.py b/py/rag-service/src/models/indexing_history.py new file mode 100644 index 0000000..b4b0958 --- /dev/null +++ b/py/rag-service/src/models/indexing_history.py @@ -0,0 +1,19 @@ +"""Indexing History Model.""" + +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, Field + + +class IndexingHistory(BaseModel): + """Model for indexing history record.""" + + id: int | None = Field(None, description="Record ID") + uri: str = Field(..., description="URI of the indexed file") + content_hash: str = Field(..., description="MD5 hash of the file content") + status: str = Field(..., description="Indexing status (indexing/completed/failed)") + timestamp: datetime = Field(default_factory=datetime.now, description="Record timestamp") + error_message: str | None = Field(None, description="Error message if failed") + document_id: str | None = Field(None, description="Document ID in the index") + metadata: dict[str, Any] | None = Field(None, description="Additional metadata") diff --git a/py/rag-service/src/models/resource.py b/py/rag-service/src/models/resource.py new file mode 100644 index 0000000..0e00b2e --- /dev/null +++ b/py/rag-service/src/models/resource.py @@ -0,0 +1,25 @@ +"""Resource Model.""" + +from datetime import datetime +from typing import Literal + +from pydantic import BaseModel, Field + + +class Resource(BaseModel): + """Model for resource record.""" + + id: int | None = Field(None, description="Resource ID") + name: str = Field(..., description="Name of the resource") + 
uri: str = Field(..., description="URI of the resource") + type: Literal["local", "remote"] = Field(..., description="Type of resource (local/remote)") + status: str = Field("active", description="Status of resource (active/inactive)") + indexing_status: Literal["pending", "indexing", "indexed", "failed"] = Field( + "pending", + description="Indexing status (pending/indexing/indexed/failed)", + ) + indexing_status_message: str | None = Field(None, description="Indexing status message") + created_at: datetime = Field(default_factory=datetime.now, description="Creation timestamp") + indexing_started_at: datetime | None = Field(None, description="Indexing start timestamp") + last_indexed_at: datetime | None = Field(None, description="Last indexing timestamp") + last_error: str | None = Field(None, description="Last error message if any") diff --git a/py/rag-service/src/services/__init__.py b/py/rag-service/src/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/py/rag-service/src/services/indexing_history.py b/py/rag-service/src/services/indexing_history.py new file mode 100644 index 0000000..71ba59e --- /dev/null +++ b/py/rag-service/src/services/indexing_history.py @@ -0,0 +1,174 @@ +import json +import os +from datetime import datetime +from typing import Any + +from libs.db import get_db_connection +from libs.logger import logger +from libs.utils import get_node_uri +from llama_index.core.schema import Document +from models.indexing_history import IndexingHistory + + +class IndexingHistoryService: + def delete_indexing_status(self, uri: str) -> None: + """Delete indexing status for a specific file.""" + with get_db_connection() as conn: + conn.execute( + """ + DELETE FROM indexing_history + WHERE uri = ? 
+ """, + (uri,), + ) + conn.commit() + + def delete_indexing_status_by_document_id(self, document_id: str) -> None: + """Delete indexing status for a specific document.""" + with get_db_connection() as conn: + conn.execute( + """ + DELETE FROM indexing_history + WHERE document_id = ? + """, + (document_id,), + ) + conn.commit() + + def update_indexing_status( + self, + doc: Document, + status: str, + error_message: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + """Update the indexing status in the database.""" + content_hash = doc.hash + + # Get URI from metadata if available + uri = get_node_uri(doc) + if not uri: + logger.warning("URI not found for document: %s", doc.doc_id) + return + + record = IndexingHistory( + id=None, + uri=uri, + content_hash=content_hash, + status=status, + error_message=error_message, + document_id=doc.doc_id, + metadata=metadata, + ) + with get_db_connection() as conn: + # Check if record exists + existing = conn.execute( + "SELECT id FROM indexing_history WHERE document_id = ?", + (doc.doc_id,), + ).fetchone() + + if existing: + # Update existing record + conn.execute( + """ + UPDATE indexing_history + SET content_hash = ?, status = ?, error_message = ?, document_id = ?, metadata = ? + WHERE uri = ? + """, + ( + record.content_hash, + record.status, + record.error_message, + record.document_id, + json.dumps(record.metadata) if record.metadata else None, + record.uri, + ), + ) + else: + # Insert new record + conn.execute( + """ + INSERT INTO indexing_history + (uri, content_hash, status, error_message, document_id, metadata) + VALUES (?, ?, ?, ?, ?, ?) 
+ """, + ( + record.uri, + record.content_hash, + record.status, + record.error_message, + record.document_id, + json.dumps(record.metadata) if record.metadata else None, + ), + ) + conn.commit() + + def get_indexing_status(self, doc: Document | None = None, base_uri: str | None = None) -> list[IndexingHistory]: + """Get indexing status from the database.""" + with get_db_connection() as conn: + if doc: + uri = get_node_uri(doc) + if not uri: + logger.warning("URI not found for document: %s", doc.doc_id) + return [] + content_hash = doc.hash + # For a specific file, get its latest status + query = """ + SELECT * + FROM indexing_history + WHERE uri = ? and content_hash = ? + ORDER BY timestamp DESC LIMIT 1 + """ + params = (uri, content_hash) + elif base_uri: + # For files in a specific directory, get their latest status + query = """ + WITH RankedHistory AS ( + SELECT *, + ROW_NUMBER() OVER (PARTITION BY document_id ORDER BY timestamp DESC) as rn + FROM indexing_history + WHERE uri LIKE ? 
|| '%' + ) + SELECT id, uri, content_hash, status, timestamp, error_message, document_id, metadata + FROM RankedHistory + WHERE rn = 1 + ORDER BY timestamp DESC + """ + params = (base_uri,) if base_uri.endswith(os.path.sep) else (base_uri + os.path.sep,) + else: + # For all files, get their latest status + query = """ + WITH RankedHistory AS ( + SELECT *, + ROW_NUMBER() OVER (PARTITION BY uri ORDER BY timestamp DESC) as rn + FROM indexing_history + ) + SELECT id, uri, content_hash, status, timestamp, error_message, document_id, metadata + FROM RankedHistory + WHERE rn = 1 + ORDER BY timestamp DESC + """ + params = () + + rows = conn.execute(query, params).fetchall() + + result = [] + for row in rows: + row_dict = dict(row) + # Parse metadata JSON if it exists + if row_dict.get("metadata"): + try: + row_dict["metadata"] = json.loads(row_dict["metadata"]) + except json.JSONDecodeError: + row_dict["metadata"] = None + # Parse timestamp string to datetime if needed + if isinstance(row_dict.get("timestamp"), str): + row_dict["timestamp"] = datetime.fromisoformat( + row_dict["timestamp"].replace("Z", "+00:00"), + ) + result.append(IndexingHistory(**row_dict)) + + return result + + +indexing_history_service = IndexingHistoryService() diff --git a/py/rag-service/src/services/resource.py b/py/rag-service/src/services/resource.py new file mode 100644 index 0000000..d894504 --- /dev/null +++ b/py/rag-service/src/services/resource.py @@ -0,0 +1,104 @@ +"""Resource Service.""" + +from libs.db import get_db_connection +from models.resource import Resource + + +class ResourceService: + """Resource Service.""" + + def add_resource_to_db(self, resource: Resource) -> None: + """Add a resource to the database.""" + with get_db_connection() as conn: + conn.execute( + """ + INSERT INTO resources (name, uri, type, status, indexing_status, created_at) + VALUES (?, ?, ?, ?, ?, ?) 
+ """, + ( + resource.name, + resource.uri, + resource.type, + resource.status, + resource.indexing_status, + resource.created_at, + ), + ) + conn.commit() + + def update_resource_indexing_status(self, uri: str, indexing_status: str, indexing_status_message: str) -> None: + """Update resource indexing status in the database.""" + with get_db_connection() as conn: + if indexing_status == "indexing": + conn.execute( + """ + UPDATE resources + SET indexing_status = ?, indexing_status_message = ?, indexing_started_at = CURRENT_TIMESTAMP + WHERE uri = ? + """, + (indexing_status, indexing_status_message, uri), + ) + else: + conn.execute( + """ + UPDATE resources + SET indexing_status = ?, indexing_status_message = ?, last_indexed_at = CURRENT_TIMESTAMP + WHERE uri = ? + """, + (indexing_status, indexing_status_message, uri), + ) + conn.commit() + + def update_resource_status(self, uri: str, status: str, error: str | None = None) -> None: + """Update resource status in the database.""" + with get_db_connection() as conn: + if status == "active": + conn.execute( + """ + UPDATE resources + SET status = ?, last_indexed_at = CURRENT_TIMESTAMP, last_error = ? + WHERE uri = ? + """, + (status, error, uri), + ) + else: + conn.execute( + """ + UPDATE resources + SET status = ?, last_error = ? + WHERE uri = ? 
+ """, + (status, error, uri), + ) + conn.commit() + + def get_resource(self, uri: str) -> Resource | None: + """Get resource from the database.""" + with get_db_connection() as conn: + row = conn.execute( + "SELECT * FROM resources WHERE uri = ?", + (uri,), + ).fetchone() + if row: + return Resource(**dict(row)) + return None + + def get_resource_by_name(self, name: str) -> Resource | None: + """Get resource by name from the database.""" + with get_db_connection() as conn: + row = conn.execute( + "SELECT * FROM resources WHERE name = ?", + (name,), + ).fetchone() + if row: + return Resource(**dict(row)) + return None + + def get_all_resources(self) -> list[Resource]: + """Get all resources from the database.""" + with get_db_connection() as conn: + rows = conn.execute("SELECT * FROM resources ORDER BY created_at DESC").fetchall() + return [Resource(**dict(row)) for row in rows] + + +resource_service = ResourceService() diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..f51cb86 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,25 @@ +{ + "include": [ + "." 
+ ], + "exclude": [ + "**/node_modules", + "**/__pycache__", + "**/.*" + ], + "defineConstant": { + "DEBUG": true + }, + "venvPath": ".", + "venv": ".venv", + "pythonVersion": "3.11", + "typeCheckingMode": "strict", + "reportMissingImports": true, + "reportMissingTypeStubs": false, + "reportUnknownMemberType": false, + "reportUnknownParameterType": false, + "reportUnknownVariableType": false, + "reportUnknownArgumentType": false, + "reportPrivateUsage": false, + "reportUntypedFunctionDecorator": false +} diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..62098c1 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,49 @@ +# Line length consistent with black +line-length = 180 + +# Directories to exclude +exclude = [ + ".git", + ".ruff_cache", + ".venv", + "venv", + "__pycache__", + "build", + "dist", +] + +# Target Python version +target-version = "py312" + +[lint] +# Enable all rule sets +select = ["ALL"] + +# Rules to ignore +ignore = [ + "A005", + "BLE001", + "D104", + "D100", + "D101", + "D203", # 1 blank line required before class docstring + "D212", # Multi-line docstring summary should start at the first line + "TRY300", + "TRY400", + "PGH003", + "PLR0911", +] + +# Allow autofix for all rules +fixable = ["ALL"] + +[format] +# Use double quotes +quote-style = "double" +# Indentation style
indent-style = "space" + +[lint.isort] +# Import sorting settings compatible with black
combine-as-imports = true
known-first-party = ["avante"]