feat: add support for ollama RAG providers (#1427)
* fix: openai env * feat: add support for multiple RAG providers - Added provider, model and endpoint configuration options for RAG service - Updated RAG service to support both OpenAI and Ollama providers - Added Ollama embedding support and dependencies - Improved environment variable handling for RAG service configuration Signed-off-by: wfhtqp@gmail.com <wfhtqp@gmail.com> * fix: update docker env * feat: rag server add ollama llm * fix: pre-commit * feat: check embed model and clean * docs: add rag server config docs * fix: pyright ignore --------- Signed-off-by: wfhtqp@gmail.com <wfhtqp@gmail.com>
This commit is contained in:
20
README.md
20
README.md
@@ -30,9 +30,9 @@
|
|||||||
>
|
>
|
||||||
> 🥰 This project is undergoing rapid iterations, and many exciting features will be added successively. Stay tuned!
|
> 🥰 This project is undergoing rapid iterations, and many exciting features will be added successively. Stay tuned!
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/510e6270-b6cf-459d-9a2f-15b397d1fe53
|
<https://github.com/user-attachments/assets/510e6270-b6cf-459d-9a2f-15b397d1fe53>
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/86140bfd-08b4-483d-a887-1b701d9e37dd
|
<https://github.com/user-attachments/assets/86140bfd-08b4-483d-a887-1b701d9e37dd>
|
||||||
|
|
||||||
## Sponsorship
|
## Sponsorship
|
||||||
|
|
||||||
@@ -275,7 +275,7 @@ require('avante').setup ({
|
|||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
>
|
>
|
||||||
> Any rendering plugins that support markdown should work with Avante as long as you add the supported filetype `Avante`. See https://github.com/yetone/avante.nvim/issues/175 and [this comment](https://github.com/yetone/avante.nvim/issues/175#issuecomment-2313749363) for more information.
|
> Any rendering plugins that support markdown should work with Avante as long as you add the supported filetype `Avante`. See <https://github.com/yetone/avante.nvim/issues/175> and [this comment](https://github.com/yetone/avante.nvim/issues/175#issuecomment-2313749363) for more information.
|
||||||
|
|
||||||
### Default setup configuration
|
### Default setup configuration
|
||||||
|
|
||||||
@@ -404,7 +404,9 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Blink.cmp users
|
## Blink.cmp users
|
||||||
|
|
||||||
For blink cmp users (nvim-cmp alternative) view below instruction for configuration
|
For blink cmp users (nvim-cmp alternative) view below instruction for configuration
|
||||||
This is achieved by emulating nvim-cmp using blink.compat
|
This is achieved by emulating nvim-cmp using blink.compat
|
||||||
or you can use [Kaiser-Yang/blink-cmp-avante](https://github.com/Kaiser-Yang/blink-cmp-avante).
|
or you can use [Kaiser-Yang/blink-cmp-avante](https://github.com/Kaiser-Yang/blink-cmp-avante).
|
||||||
@@ -471,6 +473,7 @@ To create a customized file_selector, you can specify a customized function to l
|
|||||||
|
|
||||||
Choose a selector other that native, the default as that currently has an issue
|
Choose a selector other that native, the default as that currently has an issue
|
||||||
For lazyvim users copy the full config for blink.cmp from the website or extend the options
|
For lazyvim users copy the full config for blink.cmp from the website or extend the options
|
||||||
|
|
||||||
```lua
|
```lua
|
||||||
compat = {
|
compat = {
|
||||||
"avante_commands",
|
"avante_commands",
|
||||||
@@ -478,7 +481,9 @@ For lazyvim users copy the full config for blink.cmp from the website or extend
|
|||||||
"avante_files",
|
"avante_files",
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
For other users just add a custom provider
|
For other users just add a custom provider
|
||||||
|
|
||||||
```lua
|
```lua
|
||||||
default = {
|
default = {
|
||||||
...
|
...
|
||||||
@@ -487,6 +492,7 @@ For other users just add a custom provider
|
|||||||
"avante_files",
|
"avante_files",
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
```lua
|
```lua
|
||||||
providers = {
|
providers = {
|
||||||
avante_commands = {
|
avante_commands = {
|
||||||
@@ -510,6 +516,7 @@ For other users just add a custom provider
|
|||||||
...
|
...
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -561,6 +568,7 @@ Given its early stage, `avante.nvim` currently supports the following basic func
|
|||||||
> export BEDROCK_KEYS=aws_access_key_id,aws_secret_access_key,aws_region[,aws_session_token]
|
> export BEDROCK_KEYS=aws_access_key_id,aws_secret_access_key,aws_region[,aws_session_token]
|
||||||
>
|
>
|
||||||
> ```
|
> ```
|
||||||
|
>
|
||||||
> Note: The aws_session_token is optional and only needed when using temporary AWS credentials
|
> Note: The aws_session_token is optional and only needed when using temporary AWS credentials
|
||||||
|
|
||||||
1. Open a code file in Neovim.
|
1. Open a code file in Neovim.
|
||||||
@@ -649,7 +657,11 @@ Avante provides a RAG service, which is a tool for obtaining the required contex
|
|||||||
|
|
||||||
```lua
|
```lua
|
||||||
rag_service = {
|
rag_service = {
|
||||||
enabled = true, -- Enables the rag service, requires OPENAI_API_KEY to be set
|
enabled = false, -- Enables the RAG service, requires OPENAI_API_KEY to be set
|
||||||
|
provider = "openai", -- The provider to use for RAG service (e.g. openai or ollama)
|
||||||
|
llm_model = "", -- The LLM model to use for RAG service
|
||||||
|
embed_model = "", -- The embedding model to use for RAG service
|
||||||
|
endpoint = "https://api.openai.com/v1", -- The API endpoint for RAG service
|
||||||
},
|
},
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,10 @@ M._defaults = {
|
|||||||
tokenizer = "tiktoken",
|
tokenizer = "tiktoken",
|
||||||
rag_service = {
|
rag_service = {
|
||||||
enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set
|
enabled = false, -- Enables the rag service, requires OPENAI_API_KEY to be set
|
||||||
|
provider = "openai", -- The provider to use for RAG service. eg: openai or ollama
|
||||||
|
llm_model = "", -- The LLM model to use for RAG service
|
||||||
|
embed_model = "", -- The embedding model to use for RAG service
|
||||||
|
endpoint = "https://api.openai.com/v1", -- The API endpoint for RAG service
|
||||||
},
|
},
|
||||||
web_search_engine = {
|
web_search_engine = {
|
||||||
provider = "tavily",
|
provider = "tavily",
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
local curl = require("plenary.curl")
|
local curl = require("plenary.curl")
|
||||||
local Path = require("plenary.path")
|
local Path = require("plenary.path")
|
||||||
local Utils = require("avante.utils")
|
local Utils = require("avante.utils")
|
||||||
|
local Config = require("avante.config")
|
||||||
|
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
@@ -32,12 +33,12 @@ end
|
|||||||
---@param cb fun()
|
---@param cb fun()
|
||||||
function M.launch_rag_service(cb)
|
function M.launch_rag_service(cb)
|
||||||
local openai_api_key = os.getenv("OPENAI_API_KEY")
|
local openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||||
if openai_api_key == nil then
|
if Config.rag_service.provider == "openai" then
|
||||||
error("cannot launch avante rag service, OPENAI_API_KEY is not set")
|
if openai_api_key == nil then
|
||||||
return
|
error("cannot launch avante rag service, OPENAI_API_KEY is not set")
|
||||||
|
return
|
||||||
|
end
|
||||||
end
|
end
|
||||||
local openai_base_url = os.getenv("OPENAI_BASE_URL")
|
|
||||||
if openai_base_url == nil then openai_base_url = "https://api.openai.com/v1" end
|
|
||||||
local port = M.get_rag_service_port()
|
local port = M.get_rag_service_port()
|
||||||
local image = M.get_rag_service_image()
|
local image = M.get_rag_service_image()
|
||||||
local data_path = M.get_data_path()
|
local data_path = M.get_data_path()
|
||||||
@@ -63,13 +64,17 @@ function M.launch_rag_service(cb)
|
|||||||
Utils.debug(string.format("container %s not found, starting...", container_name))
|
Utils.debug(string.format("container %s not found, starting...", container_name))
|
||||||
end
|
end
|
||||||
local cmd_ = string.format(
|
local cmd_ = string.format(
|
||||||
"docker run -d -p %d:8000 --name %s -v %s:/data -v /:/host -e DATA_DIR=/data -e OPENAI_API_KEY=%s -e OPENAI_API_BASE=%s -e OPENAI_EMBED_MODEL=%s %s",
|
"docker run -d -p %d:8000 --name %s -v %s:/data -v /:/host -e DATA_DIR=/data -e RAG_PROVIDER=%s -e %s_API_KEY=%s -e %s_API_BASE=%s -e RAG_LLM_MODEL=%s -e RAG_EMBED_MODEL=%s %s",
|
||||||
port,
|
port,
|
||||||
container_name,
|
container_name,
|
||||||
data_path,
|
data_path,
|
||||||
|
Config.rag_service.provider,
|
||||||
|
Config.rag_service.provider:upper(),
|
||||||
openai_api_key,
|
openai_api_key,
|
||||||
openai_base_url,
|
Config.rag_service.provider:upper(),
|
||||||
os.getenv("OPENAI_EMBED_MODEL"),
|
Config.rag_service.endpoint,
|
||||||
|
Config.rag_service.llm_model,
|
||||||
|
Config.rag_service.embed_model,
|
||||||
image
|
image
|
||||||
)
|
)
|
||||||
vim.fn.jobstart(cmd_, {
|
vim.fn.jobstart(cmd_, {
|
||||||
@@ -229,6 +234,7 @@ function M.retrieve(base_uri, query)
|
|||||||
query = query,
|
query = query,
|
||||||
top_k = 10,
|
top_k = 10,
|
||||||
}),
|
}),
|
||||||
|
timeout = 100000,
|
||||||
})
|
})
|
||||||
if resp.status ~= 200 then
|
if resp.status ~= 200 then
|
||||||
Utils.error("failed to retrieve: " .. resp.body)
|
Utils.error("failed to retrieve: " .. resp.body)
|
||||||
|
|||||||
@@ -61,8 +61,10 @@ llama-index-agent-openai==0.4.3
|
|||||||
llama-index-cli==0.4.0
|
llama-index-cli==0.4.0
|
||||||
llama-index-core==0.12.16.post1
|
llama-index-core==0.12.16.post1
|
||||||
llama-index-embeddings-openai==0.3.1
|
llama-index-embeddings-openai==0.3.1
|
||||||
|
llama-index-embeddings-ollama==0.5.0
|
||||||
llama-index-indices-managed-llama-cloud==0.6.4
|
llama-index-indices-managed-llama-cloud==0.6.4
|
||||||
llama-index-llms-openai==0.3.18
|
llama-index-llms-openai==0.3.18
|
||||||
|
llama-index-llms-ollama==0.5.2
|
||||||
llama-index-multi-modal-llms-openai==0.4.3
|
llama-index-multi-modal-llms-openai==0.4.3
|
||||||
llama-index-program-openai==0.3.1
|
llama-index-program-openai==0.3.1
|
||||||
llama-index-question-gen-openai==0.3.0
|
llama-index-question-gen-openai==0.3.0
|
||||||
@@ -163,3 +165,4 @@ websockets==14.2
|
|||||||
wrapt==1.17.2
|
wrapt==1.17.2
|
||||||
yarl==1.18.3
|
yarl==1.18.3
|
||||||
zipp==3.21.0
|
zipp==3.21.0
|
||||||
|
docx2txt==0.8.0
|
||||||
|
|||||||
@@ -44,7 +44,10 @@ from llama_index.core import (
|
|||||||
)
|
)
|
||||||
from llama_index.core.node_parser import CodeSplitter
|
from llama_index.core.node_parser import CodeSplitter
|
||||||
from llama_index.core.schema import Document
|
from llama_index.core.schema import Document
|
||||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||||
|
from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingModelType
|
||||||
|
from llama_index.llms.ollama import Ollama
|
||||||
|
from llama_index.llms.openai import OpenAI
|
||||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||||
from markdownify import markdownify as md
|
from markdownify import markdownify as md
|
||||||
from models.indexing_history import IndexingHistory # noqa: TC002
|
from models.indexing_history import IndexingHistory # noqa: TC002
|
||||||
@@ -311,14 +314,57 @@ init_db()
|
|||||||
|
|
||||||
# Initialize ChromaDB and LlamaIndex services
|
# Initialize ChromaDB and LlamaIndex services
|
||||||
chroma_client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR))
|
chroma_client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR))
|
||||||
chroma_collection = chroma_client.get_or_create_collection("documents")
|
|
||||||
|
# Check if provider or model has changed
|
||||||
|
current_provider = os.getenv("RAG_PROVIDER", "openai").lower()
|
||||||
|
current_embed_model = os.getenv("RAG_EMBED_MODEL", "")
|
||||||
|
current_llm_model = os.getenv("RAG_LLM_MODEL", "")
|
||||||
|
|
||||||
|
# Try to read previous config
|
||||||
|
config_file = BASE_DATA_DIR / "rag_config.json"
|
||||||
|
if config_file.exists():
|
||||||
|
with Path.open(config_file, "r") as f:
|
||||||
|
prev_config = json.load(f)
|
||||||
|
if prev_config.get("provider") != current_provider or prev_config.get("embed_model") != current_embed_model:
|
||||||
|
# Clear existing data if config changed
|
||||||
|
logger.info("Detected config change, clearing existing data...")
|
||||||
|
chroma_client.reset()
|
||||||
|
|
||||||
|
# Save current config
|
||||||
|
with Path.open(config_file, "w") as f:
|
||||||
|
json.dump({"provider": current_provider, "embed_model": current_embed_model}, f)
|
||||||
|
|
||||||
|
chroma_collection = chroma_client.get_or_create_collection("documents") # pyright: ignore
|
||||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||||
embed_model = OpenAIEmbedding()
|
|
||||||
model = os.getenv("OPENAI_EMBED_MODEL", "")
|
# Initialize embedding model based on provider
|
||||||
if model:
|
llm_provider = current_provider
|
||||||
embed_model = OpenAIEmbedding(model=model)
|
base_url = os.getenv(llm_provider.upper() + "_API_BASE", "")
|
||||||
|
rag_embed_model = current_embed_model
|
||||||
|
rag_llm_model = current_llm_model
|
||||||
|
|
||||||
|
if llm_provider == "ollama":
|
||||||
|
if base_url == "":
|
||||||
|
base_url = "http://localhost:11434"
|
||||||
|
if rag_embed_model == "":
|
||||||
|
rag_embed_model = "nomic-embed-text"
|
||||||
|
if rag_llm_model == "":
|
||||||
|
rag_llm_model = "llama3"
|
||||||
|
embed_model = OllamaEmbedding(model_name=rag_embed_model, base_url=base_url)
|
||||||
|
llm_model = Ollama(model=rag_llm_model, base_url=base_url, request_timeout=60.0)
|
||||||
|
else:
|
||||||
|
if base_url == "":
|
||||||
|
base_url = "https://api.openai.com/v1"
|
||||||
|
if rag_embed_model == "":
|
||||||
|
rag_embed_model = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002
|
||||||
|
if rag_llm_model == "":
|
||||||
|
rag_llm_model = "gpt-3.5-turbo"
|
||||||
|
embed_model = OpenAIEmbedding(model=rag_embed_model, api_base=base_url)
|
||||||
|
llm_model = OpenAI(model=rag_llm_model, api_base=base_url)
|
||||||
|
|
||||||
Settings.embed_model = embed_model
|
Settings.embed_model = embed_model
|
||||||
|
Settings.llm = llm_model
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user