From de7cccd0898d9256c35b94b8290accbef360c874 Mon Sep 17 00:00:00 2001 From: nzlov Date: Tue, 4 Mar 2025 11:07:40 +0800 Subject: [PATCH] feat: add support for ollama RAG providers (#1427) * fix: openai env * feat: add support for multiple RAG providers - Added provider, model and endpoint configuration options for RAG service - Updated RAG service to support both OpenAI and Ollama providers - Added Ollama embedding support and dependencies - Improved environment variable handling for RAG service configuration Signed-off-by: wfhtqp@gmail.com * fix: update docker env * feat: rag server add ollama llm * fix: pre-commit * feat: check embed model and clean * docs: add rag server config docs * fix: pyright ignore --------- Signed-off-by: wfhtqp@gmail.com --- README.md | 20 +++++++++--- lua/avante/config.lua | 4 +++ lua/avante/rag_service.lua | 22 ++++++++----- py/rag-service/requirements.txt | 3 ++ py/rag-service/src/main.py | 58 +++++++++++++++++++++++++++++---- 5 files changed, 89 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index f56a07c..a6afb30 100644 --- a/README.md +++ b/README.md @@ -30,9 +30,9 @@ > > 🥰 This project is undergoing rapid iterations, and many exciting features will be added successively. Stay tuned! -https://github.com/user-attachments/assets/510e6270-b6cf-459d-9a2f-15b397d1fe53 +<https://github.com/user-attachments/assets/510e6270-b6cf-459d-9a2f-15b397d1fe53> -https://github.com/user-attachments/assets/86140bfd-08b4-483d-a887-1b701d9e37dd +<https://github.com/user-attachments/assets/86140bfd-08b4-483d-a887-1b701d9e37dd> ## Sponsorship @@ -275,7 +275,7 @@ require('avante').setup ({ > [!TIP] > -> Any rendering plugins that support markdown should work with Avante as long as you add the supported filetype `Avante`. See https://github.com/yetone/avante.nvim/issues/175 and [this comment](https://github.com/yetone/avante.nvim/issues/175#issuecomment-2313749363) for more information. +> Any rendering plugins that support markdown should work with Avante as long as you add the supported filetype `Avante`. 
See <https://github.com/yetone/avante.nvim/issues/175> and [this comment](https://github.com/yetone/avante.nvim/issues/175#issuecomment-2313749363) for more information. ### Default setup configuration @@ -404,7 +404,9 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_ }, } ``` + ## Blink.cmp users + For blink cmp users (nvim-cmp alternative) view below instruction for configuration This is achieved by emulating nvim-cmp using blink.compat or you can use [Kaiser-Yang/blink-cmp-avante](https://github.com/Kaiser-Yang/blink-cmp-avante). @@ -471,6 +473,7 @@ To create a customized file_selector, you can specify a customized function to l Choose a selector other that native, the default as that currently has an issue For lazyvim users copy the full config for blink.cmp from the website or extend the options + ```lua compat = { "avante_commands", @@ -478,7 +481,9 @@ For lazyvim users copy the full config for blink.cmp from the website or extend "avante_files", } ``` + For other users just add a custom provider + ```lua default = { ... @@ -487,6 +492,7 @@ For other users just add a custom provider "avante_files", } ``` + ```lua providers = { avante_commands = { @@ -510,6 +516,7 @@ For other users just add a custom provider ... } ``` + ## Usage @@ -561,6 +568,7 @@ Given its early stage, `avante.nvim` currently supports the following basic func > export BEDROCK_KEYS=aws_access_key_id,aws_secret_access_key,aws_region[,aws_session_token] > > ``` +> > Note: The aws_session_token is optional and only needed when using temporary AWS credentials 1. Open a code file in Neovim. @@ -649,7 +657,11 @@ Avante provides a RAG service, which is a tool for obtaining the required contex ```lua rag_service = { - enabled = true, -- Enables the rag service, requires OPENAI_API_KEY to be set + enabled = false, -- Enables the RAG service, requires OPENAI_API_KEY to be set + provider = "openai", -- The provider to use for RAG service (e.g. 
openai or ollama) + llm_model = "", -- The LLM model to use for RAG service + embed_model = "", -- The embedding model to use for RAG service + endpoint = "https://api.openai.com/v1", -- The API endpoint for RAG service }, ``` diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 7113632..c865004 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -35,6 +35,10 @@ M._defaults = { tokenizer = "tiktoken", rag_service = { enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set + provider = "openai", -- The provider to use for RAG service. eg: openai or ollama + llm_model = "", -- The LLM model to use for RAG service + embed_model = "", -- The embedding model to use for RAG service + endpoint = "https://api.openai.com/v1", -- The API endpoint for RAG service }, web_search_engine = { provider = "tavily", diff --git a/lua/avante/rag_service.lua b/lua/avante/rag_service.lua index cbec6ec..ae601fc 100644 --- a/lua/avante/rag_service.lua +++ b/lua/avante/rag_service.lua @@ -1,6 +1,7 @@ local curl = require("plenary.curl") local Path = require("plenary.path") local Utils = require("avante.utils") +local Config = require("avante.config") local M = {} @@ -32,12 +33,12 @@ end ---@param cb fun() function M.launch_rag_service(cb) local openai_api_key = os.getenv("OPENAI_API_KEY") - if openai_api_key == nil then - error("cannot launch avante rag service, OPENAI_API_KEY is not set") - return + if Config.rag_service.provider == "openai" then + if openai_api_key == nil then + error("cannot launch avante rag service, OPENAI_API_KEY is not set") + return + end end - local openai_base_url = os.getenv("OPENAI_BASE_URL") - if openai_base_url == nil then openai_base_url = "https://api.openai.com/v1" end local port = M.get_rag_service_port() local image = M.get_rag_service_image() local data_path = M.get_data_path() @@ -63,13 +64,17 @@ function M.launch_rag_service(cb) Utils.debug(string.format("container %s not found, starting...", 
container_name)) end local cmd_ = string.format( - "docker run -d -p %d:8000 --name %s -v %s:/data -v /:/host -e DATA_DIR=/data -e OPENAI_API_KEY=%s -e OPENAI_API_BASE=%s -e OPENAI_EMBED_MODEL=%s %s", + "docker run -d -p %d:8000 --name %s -v %s:/data -v /:/host -e DATA_DIR=/data -e RAG_PROVIDER=%s -e %s_API_KEY=%s -e %s_API_BASE=%s -e RAG_LLM_MODEL=%s -e RAG_EMBED_MODEL=%s %s", port, container_name, data_path, + Config.rag_service.provider, + Config.rag_service.provider:upper(), openai_api_key, - openai_base_url, - os.getenv("OPENAI_EMBED_MODEL"), + Config.rag_service.provider:upper(), + Config.rag_service.endpoint, + Config.rag_service.llm_model, + Config.rag_service.embed_model, image ) vim.fn.jobstart(cmd_, { @@ -229,6 +234,7 @@ function M.retrieve(base_uri, query) query = query, top_k = 10, }), + timeout = 100000, }) if resp.status ~= 200 then Utils.error("failed to retrieve: " .. resp.body) diff --git a/py/rag-service/requirements.txt b/py/rag-service/requirements.txt index 78b0cda..10c0096 100644 --- a/py/rag-service/requirements.txt +++ b/py/rag-service/requirements.txt @@ -61,8 +61,10 @@ llama-index-agent-openai==0.4.3 llama-index-cli==0.4.0 llama-index-core==0.12.16.post1 llama-index-embeddings-openai==0.3.1 +llama-index-embeddings-ollama==0.5.0 llama-index-indices-managed-llama-cloud==0.6.4 llama-index-llms-openai==0.3.18 +llama-index-llms-ollama==0.5.2 llama-index-multi-modal-llms-openai==0.4.3 llama-index-program-openai==0.3.1 llama-index-question-gen-openai==0.3.0 @@ -163,3 +165,4 @@ websockets==14.2 wrapt==1.17.2 yarl==1.18.3 zipp==3.21.0 +docx2txt==0.8.0 diff --git a/py/rag-service/src/main.py b/py/rag-service/src/main.py index 1050c06..e33ccb5 100644 --- a/py/rag-service/src/main.py +++ b/py/rag-service/src/main.py @@ -44,7 +44,10 @@ from llama_index.core import ( ) from llama_index.core.node_parser import CodeSplitter from llama_index.core.schema import Document -from llama_index.embeddings.openai import OpenAIEmbedding +from 
llama_index.embeddings.ollama import OllamaEmbedding +from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingModelType +from llama_index.llms.ollama import Ollama +from llama_index.llms.openai import OpenAI from llama_index.vector_stores.chroma import ChromaVectorStore from markdownify import markdownify as md from models.indexing_history import IndexingHistory # noqa: TC002 @@ -311,14 +314,57 @@ init_db() # Initialize ChromaDB and LlamaIndex services chroma_client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR)) -chroma_collection = chroma_client.get_or_create_collection("documents") + +# Check if provider or model has changed +current_provider = os.getenv("RAG_PROVIDER", "openai").lower() +current_embed_model = os.getenv("RAG_EMBED_MODEL", "") +current_llm_model = os.getenv("RAG_LLM_MODEL", "") + +# Try to read previous config +config_file = BASE_DATA_DIR / "rag_config.json" +if config_file.exists(): + with Path.open(config_file, "r") as f: + prev_config = json.load(f) + if prev_config.get("provider") != current_provider or prev_config.get("embed_model") != current_embed_model: + # Clear existing data if config changed + logger.info("Detected config change, clearing existing data...") + chroma_client.reset() + +# Save current config +with Path.open(config_file, "w") as f: + json.dump({"provider": current_provider, "embed_model": current_embed_model}, f) + +chroma_collection = chroma_client.get_or_create_collection("documents") # pyright: ignore vector_store = ChromaVectorStore(chroma_collection=chroma_collection) storage_context = StorageContext.from_defaults(vector_store=vector_store) -embed_model = OpenAIEmbedding() -model = os.getenv("OPENAI_EMBED_MODEL", "") -if model: - embed_model = OpenAIEmbedding(model=model) + +# Initialize embedding model based on provider +llm_provider = current_provider +base_url = os.getenv(llm_provider.upper() + "_API_BASE", "") +rag_embed_model = current_embed_model +rag_llm_model = current_llm_model + 
+if llm_provider == "ollama": + if base_url == "": + base_url = "http://localhost:11434" + if rag_embed_model == "": + rag_embed_model = "nomic-embed-text" + if rag_llm_model == "": + rag_llm_model = "llama3" + embed_model = OllamaEmbedding(model_name=rag_embed_model, base_url=base_url) + llm_model = Ollama(model=rag_llm_model, base_url=base_url, request_timeout=60.0) +else: + if base_url == "": + base_url = "https://api.openai.com/v1" + if rag_embed_model == "": + rag_embed_model = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002 + if rag_llm_model == "": + rag_llm_model = "gpt-3.5-turbo" + embed_model = OpenAIEmbedding(model=rag_embed_model, api_base=base_url) + llm_model = OpenAI(model=rag_llm_model, api_base=base_url) + Settings.embed_model = embed_model +Settings.llm = llm_model try: