feat: add support for ollama RAG providers (#1427)

* fix: openai env

* feat: add support for multiple RAG providers

- Added provider, model and endpoint configuration options for RAG service

- Updated RAG service to support both OpenAI and Ollama providers

- Added Ollama embedding support and dependencies

- Improved environment variable handling for RAG service configuration

Signed-off-by: wfhtqp@gmail.com <wfhtqp@gmail.com>

* fix: update docker env

* feat: rag server add ollama llm

* fix: pre-commit

* feat: check embed model and clean

* docs: add rag server config docs

* fix: pyright ignore

---------

Signed-off-by: wfhtqp@gmail.com <wfhtqp@gmail.com>
This commit is contained in:
nzlov
2025-03-04 11:07:40 +08:00
committed by GitHub
parent b01121bc39
commit de7cccd089
5 changed files with 89 additions and 18 deletions

View File

@@ -30,9 +30,9 @@
> >
> 🥰 This project is undergoing rapid iterations, and many exciting features will be added successively. Stay tuned! > 🥰 This project is undergoing rapid iterations, and many exciting features will be added successively. Stay tuned!
https://github.com/user-attachments/assets/510e6270-b6cf-459d-9a2f-15b397d1fe53 <https://github.com/user-attachments/assets/510e6270-b6cf-459d-9a2f-15b397d1fe53>
https://github.com/user-attachments/assets/86140bfd-08b4-483d-a887-1b701d9e37dd <https://github.com/user-attachments/assets/86140bfd-08b4-483d-a887-1b701d9e37dd>
## Sponsorship ## Sponsorship
@@ -275,7 +275,7 @@ require('avante').setup ({
> [!TIP] > [!TIP]
> >
> Any rendering plugins that support markdown should work with Avante as long as you add the supported filetype `Avante`. See https://github.com/yetone/avante.nvim/issues/175 and [this comment](https://github.com/yetone/avante.nvim/issues/175#issuecomment-2313749363) for more information. > Any rendering plugins that support markdown should work with Avante as long as you add the supported filetype `Avante`. See <https://github.com/yetone/avante.nvim/issues/175> and [this comment](https://github.com/yetone/avante.nvim/issues/175#issuecomment-2313749363) for more information.
### Default setup configuration ### Default setup configuration
@@ -404,7 +404,9 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
}, },
} }
``` ```
## Blink.cmp users ## Blink.cmp users
For blink cmp users (nvim-cmp alternative) view below instruction for configuration For blink cmp users (nvim-cmp alternative) view below instruction for configuration
This is achieved by emulating nvim-cmp using blink.compat This is achieved by emulating nvim-cmp using blink.compat
or you can use [Kaiser-Yang/blink-cmp-avante](https://github.com/Kaiser-Yang/blink-cmp-avante). or you can use [Kaiser-Yang/blink-cmp-avante](https://github.com/Kaiser-Yang/blink-cmp-avante).
@@ -471,6 +473,7 @@ To create a customized file_selector, you can specify a customized function to l
Choose a selector other that native, the default as that currently has an issue Choose a selector other that native, the default as that currently has an issue
For lazyvim users copy the full config for blink.cmp from the website or extend the options For lazyvim users copy the full config for blink.cmp from the website or extend the options
```lua ```lua
compat = { compat = {
"avante_commands", "avante_commands",
@@ -478,7 +481,9 @@ For lazyvim users copy the full config for blink.cmp from the website or extend
"avante_files", "avante_files",
} }
``` ```
For other users just add a custom provider For other users just add a custom provider
```lua ```lua
default = { default = {
... ...
@@ -487,6 +492,7 @@ For other users just add a custom provider
"avante_files", "avante_files",
} }
``` ```
```lua ```lua
providers = { providers = {
avante_commands = { avante_commands = {
@@ -510,6 +516,7 @@ For other users just add a custom provider
... ...
} }
``` ```
</details> </details>
## Usage ## Usage
@@ -561,6 +568,7 @@ Given its early stage, `avante.nvim` currently supports the following basic func
> export BEDROCK_KEYS=aws_access_key_id,aws_secret_access_key,aws_region[,aws_session_token] > export BEDROCK_KEYS=aws_access_key_id,aws_secret_access_key,aws_region[,aws_session_token]
> >
> ``` > ```
>
> Note: The aws_session_token is optional and only needed when using temporary AWS credentials > Note: The aws_session_token is optional and only needed when using temporary AWS credentials
1. Open a code file in Neovim. 1. Open a code file in Neovim.
@@ -649,7 +657,11 @@ Avante provides a RAG service, which is a tool for obtaining the required contex
```lua ```lua
rag_service = { rag_service = {
enabled = true, -- Enables the rag service, requires OPENAI_API_KEY to be set enabled = false, -- Enables the RAG service, requires OPENAI_API_KEY to be set
provider = "openai", -- The provider to use for RAG service (e.g. openai or ollama)
llm_model = "", -- The LLM model to use for RAG service
embed_model = "", -- The embedding model to use for RAG service
endpoint = "https://api.openai.com/v1", -- The API endpoint for RAG service
}, },
``` ```

View File

@@ -35,6 +35,10 @@ M._defaults = {
tokenizer = "tiktoken", tokenizer = "tiktoken",
rag_service = { rag_service = {
enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set
provider = "openai", -- The provider to use for RAG service. eg: openai or ollama
llm_model = "", -- The LLM model to use for RAG service
embed_model = "", -- The embedding model to use for RAG service
endpoint = "https://api.openai.com/v1", -- The API endpoint for RAG service
}, },
web_search_engine = { web_search_engine = {
provider = "tavily", provider = "tavily",

View File

@@ -1,6 +1,7 @@
local curl = require("plenary.curl") local curl = require("plenary.curl")
local Path = require("plenary.path") local Path = require("plenary.path")
local Utils = require("avante.utils") local Utils = require("avante.utils")
local Config = require("avante.config")
local M = {} local M = {}
@@ -32,12 +33,12 @@ end
---@param cb fun() ---@param cb fun()
function M.launch_rag_service(cb) function M.launch_rag_service(cb)
local openai_api_key = os.getenv("OPENAI_API_KEY") local openai_api_key = os.getenv("OPENAI_API_KEY")
if openai_api_key == nil then if Config.rag_service.provider == "openai" then
error("cannot launch avante rag service, OPENAI_API_KEY is not set") if openai_api_key == nil then
return error("cannot launch avante rag service, OPENAI_API_KEY is not set")
return
end
end end
local openai_base_url = os.getenv("OPENAI_BASE_URL")
if openai_base_url == nil then openai_base_url = "https://api.openai.com/v1" end
local port = M.get_rag_service_port() local port = M.get_rag_service_port()
local image = M.get_rag_service_image() local image = M.get_rag_service_image()
local data_path = M.get_data_path() local data_path = M.get_data_path()
@@ -63,13 +64,17 @@ function M.launch_rag_service(cb)
Utils.debug(string.format("container %s not found, starting...", container_name)) Utils.debug(string.format("container %s not found, starting...", container_name))
end end
local cmd_ = string.format( local cmd_ = string.format(
"docker run -d -p %d:8000 --name %s -v %s:/data -v /:/host -e DATA_DIR=/data -e OPENAI_API_KEY=%s -e OPENAI_API_BASE=%s -e OPENAI_EMBED_MODEL=%s %s", "docker run -d -p %d:8000 --name %s -v %s:/data -v /:/host -e DATA_DIR=/data -e RAG_PROVIDER=%s -e %s_API_KEY=%s -e %s_API_BASE=%s -e RAG_LLM_MODEL=%s -e RAG_EMBED_MODEL=%s %s",
port, port,
container_name, container_name,
data_path, data_path,
Config.rag_service.provider,
Config.rag_service.provider:upper(),
openai_api_key, openai_api_key,
openai_base_url, Config.rag_service.provider:upper(),
os.getenv("OPENAI_EMBED_MODEL"), Config.rag_service.endpoint,
Config.rag_service.llm_model,
Config.rag_service.embed_model,
image image
) )
vim.fn.jobstart(cmd_, { vim.fn.jobstart(cmd_, {
@@ -229,6 +234,7 @@ function M.retrieve(base_uri, query)
query = query, query = query,
top_k = 10, top_k = 10,
}), }),
timeout = 100000,
}) })
if resp.status ~= 200 then if resp.status ~= 200 then
Utils.error("failed to retrieve: " .. resp.body) Utils.error("failed to retrieve: " .. resp.body)

View File

@@ -61,8 +61,10 @@ llama-index-agent-openai==0.4.3
llama-index-cli==0.4.0 llama-index-cli==0.4.0
llama-index-core==0.12.16.post1 llama-index-core==0.12.16.post1
llama-index-embeddings-openai==0.3.1 llama-index-embeddings-openai==0.3.1
llama-index-embeddings-ollama==0.5.0
llama-index-indices-managed-llama-cloud==0.6.4 llama-index-indices-managed-llama-cloud==0.6.4
llama-index-llms-openai==0.3.18 llama-index-llms-openai==0.3.18
llama-index-llms-ollama==0.5.2
llama-index-multi-modal-llms-openai==0.4.3 llama-index-multi-modal-llms-openai==0.4.3
llama-index-program-openai==0.3.1 llama-index-program-openai==0.3.1
llama-index-question-gen-openai==0.3.0 llama-index-question-gen-openai==0.3.0
@@ -163,3 +165,4 @@ websockets==14.2
wrapt==1.17.2 wrapt==1.17.2
yarl==1.18.3 yarl==1.18.3
zipp==3.21.0 zipp==3.21.0
docx2txt==0.8.0

View File

@@ -44,7 +44,10 @@ from llama_index.core import (
) )
from llama_index.core.node_parser import CodeSplitter from llama_index.core.node_parser import CodeSplitter
from llama_index.core.schema import Document from llama_index.core.schema import Document
from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingModelType
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.chroma import ChromaVectorStore
from markdownify import markdownify as md from markdownify import markdownify as md
from models.indexing_history import IndexingHistory # noqa: TC002 from models.indexing_history import IndexingHistory # noqa: TC002
@@ -311,14 +314,57 @@ init_db()
# Initialize ChromaDB and LlamaIndex services # Initialize ChromaDB and LlamaIndex services
chroma_client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR)) chroma_client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR))
chroma_collection = chroma_client.get_or_create_collection("documents")
# Check if provider or model has changed
current_provider = os.getenv("RAG_PROVIDER", "openai").lower()
current_embed_model = os.getenv("RAG_EMBED_MODEL", "")
current_llm_model = os.getenv("RAG_LLM_MODEL", "")
# Try to read previous config
config_file = BASE_DATA_DIR / "rag_config.json"
if config_file.exists():
with Path.open(config_file, "r") as f:
prev_config = json.load(f)
if prev_config.get("provider") != current_provider or prev_config.get("embed_model") != current_embed_model:
# Clear existing data if config changed
logger.info("Detected config change, clearing existing data...")
chroma_client.reset()
# Save current config
with Path.open(config_file, "w") as f:
json.dump({"provider": current_provider, "embed_model": current_embed_model}, f)
chroma_collection = chroma_client.get_or_create_collection("documents") # pyright: ignore
vector_store = ChromaVectorStore(chroma_collection=chroma_collection) vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store) storage_context = StorageContext.from_defaults(vector_store=vector_store)
embed_model = OpenAIEmbedding()
model = os.getenv("OPENAI_EMBED_MODEL", "") # Initialize embedding model based on provider
if model: llm_provider = current_provider
embed_model = OpenAIEmbedding(model=model) base_url = os.getenv(llm_provider.upper() + "_API_BASE", "")
rag_embed_model = current_embed_model
rag_llm_model = current_llm_model
if llm_provider == "ollama":
if base_url == "":
base_url = "http://localhost:11434"
if rag_embed_model == "":
rag_embed_model = "nomic-embed-text"
if rag_llm_model == "":
rag_llm_model = "llama3"
embed_model = OllamaEmbedding(model_name=rag_embed_model, base_url=base_url)
llm_model = Ollama(model=rag_llm_model, base_url=base_url, request_timeout=60.0)
else:
if base_url == "":
base_url = "https://api.openai.com/v1"
if rag_embed_model == "":
rag_embed_model = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002
if rag_llm_model == "":
rag_llm_model = "gpt-3.5-turbo"
embed_model = OpenAIEmbedding(model=rag_embed_model, api_base=base_url)
llm_model = OpenAI(model=rag_llm_model, api_base=base_url)
Settings.embed_model = embed_model Settings.embed_model = embed_model
Settings.llm = llm_model
try: try: