From 750ee80971c50349b1be68755eba47171aec7044 Mon Sep 17 00:00:00 2001 From: yetone Date: Mon, 10 Mar 2025 02:23:56 +0800 Subject: [PATCH] feat: add ollama as supported provider (#1543) * feat: add ollama as supported provider *This implementation is only working with `stream = true`* - Uses the actual ollama api and allows for passing additional options - Properly passes the system prompt to api Use ollama as provider in opts like this: opts = { debug = true, provider = "ollama", ollama = { api_key_name = "", endpoint = "http://127.0.0.1:11434", model = "qwen2.5-coder:latest", options = { num_ctx = 32768, temperature = 0, }, stream = true, }, * fix: ollama types --------- Co-authored-by: jtabke <25010496+jtabke@users.noreply.github.com> --- README.md | 11 ++++++ lua/avante/config.lua | 11 +++++- lua/avante/providers/init.lua | 8 +++- lua/avante/providers/ollama.lua | 70 +++++++++++++++++++++++++++++++++ lua/avante/providers/openai.lua | 5 +-- 5 files changed, 100 insertions(+), 5 deletions(-) create mode 100644 lua/avante/providers/ollama.lua diff --git a/README.md b/README.md index 78d2a39..98c6cc6 100644 --- a/README.md +++ b/README.md @@ -683,6 +683,17 @@ return { See [highlights.lua](./lua/avante/highlights.lua) for more information +## Ollama + +ollama is a first-class provider for avante.nvim. You can use it by setting `provider = "ollama"` in the configuration, and set the `model` field in `ollama` to the model you want to use. For example: + +```lua +provider = "ollama", +ollama = { + model = "qwq:32b", +} +``` + ## Custom providers Avante provides a set of default providers, but users can also create their own providers. 
diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 283709b..4edb9ab 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -20,7 +20,7 @@ local M = {} ---@field custom_tools AvanteLLMToolPublic[] M._defaults = { debug = false, - ---@alias ProviderName "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | "bedrock" | string + ---@alias ProviderName "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | "bedrock" | "ollama" | string provider = "claude", -- WARNING: Since auto-suggestions are a high-frequency operation and therefore expensive, -- currently designating it as `copilot` provider is dangerous because: https://github.com/yetone/avante.nvim/issues/1048 @@ -255,6 +255,15 @@ M._defaults = { temperature = 0, max_tokens = 4096, }, + ---@type AvanteSupportedProvider + ollama = { + endpoint = "http://127.0.0.1:11434", + timeout = 30000, -- Timeout in milliseconds + options = { + temperature = 0, + num_ctx = 4096, + }, + }, ---To add support for custom provider, follow the format below ---See https://github.com/yetone/avante.nvim/wiki#custom-providers for more details ---@type {[string]: AvanteProvider} diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index 447dd65..bfc8d0e 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -18,6 +18,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@field gemini AvanteProviderFunctor ---@field cohere AvanteProviderFunctor ---@field bedrock AvanteBedrockProviderFunctor +---@field ollama AvanteProviderFunctor local M = {} ---@class EnvironmentHandler @@ -152,8 +153,13 @@ M = setmetatable(M, { __index = function(t, k) local provider_config = M.get_config(k) + if Config.vendors[k] ~= nil and k == "ollama" then + Utils.warn( + "ollama is now a first-class provider in avante.nvim, please stop using vendors to define ollama, for migration guide please refer to: 
--- Ollama provider for avante.nvim.
--- Talks to an Ollama server (local by default) via its streaming
--- /api/chat endpoint; responses arrive as newline-delimited JSON.
local Utils = require("avante.utils")
local P = require("avante.providers")

---@class AvanteProviderFunctor
local M = {}

-- Ollama typically runs locally and does not require an API key.
M.api_key_name = ""

M.role_map = {
  user = "user",
  assistant = "assistant",
}

-- Reuse the OpenAI-compatible helpers. NOTE: parse_messages already
-- prepends the system prompt as a chat message, so the request body must
-- not carry it a second time (see parse_curl_args below).
M.parse_messages = P.openai.parse_messages
M.is_o_series_model = P.openai.is_o_series_model

function M:is_disable_stream() return false end

--- Handle one line of Ollama's NDJSON streaming response.
---@param ctx table per-request context (unused here)
---@param data string a single JSON-encoded line from the response stream
---@param handler_opts {on_chunk: fun(chunk: string), on_stop: fun(opts: table)}
function M:parse_stream_data(ctx, data, handler_opts)
  local ok, json_data = pcall(vim.json.decode, data)
  if not ok or not json_data then
    -- Not a valid JSON line (e.g. partial frame); log when debugging and skip.
    Utils.debug("Failed to parse JSON", data)
    return
  end

  -- Each streamed chunk carries an incremental assistant message.
  if json_data.message and json_data.message.content then
    local content = json_data.message.content
    if content and content ~= "" then handler_opts.on_chunk(content) end
  end

  -- The final frame is flagged with `done = true`.
  if json_data.done then
    handler_opts.on_stop({ reason = "complete" })
    return
  end
end

--- Build the curl request table for Ollama's /api/chat endpoint.
--- Raises if `model` or `endpoint` is missing from the provider config.
---@param prompt_opts AvantePromptOptions
---@return table curl args: { url, headers, body }
function M:parse_curl_args(prompt_opts)
  local provider_conf, request_body = P.parse_config(self)

  if not provider_conf.model or provider_conf.model == "" then error("Ollama model must be specified in config") end
  if not provider_conf.endpoint then error("Ollama requires endpoint configuration") end

  return {
    url = Utils.url_join(provider_conf.endpoint, "/api/chat"),
    headers = {
      ["Content-Type"] = "application/json",
      ["Accept"] = "application/json",
    },
    -- User-supplied request options (e.g. `options`, `stream`) override the
    -- defaults below via "force" merge.
    body = vim.tbl_deep_extend("force", {
      model = provider_conf.model,
      -- parse_messages already injects the system prompt as the first
      -- message; sending a separate top-level `system` field would
      -- duplicate it (and `system` is only defined for /api/generate,
      -- not /api/chat).
      messages = self:parse_messages(prompt_opts),
      stream = true,
    }, request_body),
  }
end

--- Surface an Ollama HTTP error to the user.
---@param result table curl result; `result.body` may hold a JSON `error` field
M.on_error = function(result)
  local error_msg = "Ollama API error"
  if result.body then
    local ok, body = pcall(vim.json.decode, result.body)
    if ok and body.error then error_msg = body.error end
  end
  Utils.error(error_msg, { title = "Ollama" })
end

return M