diff --git a/lua/avante/model_selector.lua b/lua/avante/model_selector.lua index fcadbd5..f2fff4a 100644 --- a/lua/avante/model_selector.lua +++ b/lua/avante/model_selector.lua @@ -75,6 +75,9 @@ function M.open() ::continue:: end + -- Sort models by name for stable display + table.sort(models, function(a, b) return (a.name or "") < (b.name or "") end) + if #models == 0 then Utils.warn("No models available in config") return diff --git a/lua/avante/providers/ollama.lua b/lua/avante/providers/ollama.lua index ac18fcd..71ba625 100644 --- a/lua/avante/providers/ollama.lua +++ b/lua/avante/providers/ollama.lua @@ -233,4 +233,59 @@ M.on_error = function(result) Utils.error(error_msg, { title = "Ollama" }) end +-- List available models using Ollama's tags API +function M:models_list() + -- Return cached models if available + if self._model_list_cache then return self._model_list_cache end + + -- Parse provider config and construct tags endpoint URL + local provider_conf = Providers.parse_config(self) + if not provider_conf.endpoint then error("Ollama requires endpoint configuration") end + + local curl = require("plenary.curl") + local tags_url = Utils.url_join(provider_conf.endpoint, "/api/tags") + local base_headers = { + ["Content-Type"] = "application/json", + ["Accept"] = "application/json", + } + local headers = Utils.tbl_override(base_headers, self.extra_headers) + + -- Request the model tags from Ollama + local response = curl.get(tags_url, { headers = headers }) + if response.status ~= 200 then + Utils.error("Failed to fetch Ollama models: " .. (response.body or response.status)) + return {} + end + + -- Parse the response body + local ok, res_body = pcall(vim.json.decode, response.body) + if not ok or not res_body.models then return {} end + + -- Helper to format model display string from its details + local function format_display_name(details) + local parts = {} + for _, key in ipairs({ "family", "parameter_size", "quantization_level" }) do + if details[key] then table.insert(parts, details[key]) end + end + return table.concat(parts, ", ") + end + + -- Format the models list + local models = {} + for _, model in ipairs(res_body.models) do + local details = model.details or {} + local display = format_display_name(details) + table.insert(models, { + id = model.name, + name = string.format("ollama/%s (%s)", model.name, display), + display_name = model.name, + provider_name = "ollama", + version = model.digest, + }) + end + + self._model_list_cache = models + return models +end + return M diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 76ca830..d94308d 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -291,6 +291,19 @@ vim.g.avante_login = vim.g.avante_login ---@alias AvanteLLMStopCallback fun(opts: AvanteLLMStopCallbackOptions): nil ---@alias AvanteLLMConfigHandler fun(opts: AvanteSupportedProvider): AvanteDefaultBaseProvider, table --- +---@class AvanteProviderModel +---@field id string +---@field name string +---@field display_name string +---@field provider_name string +---@field version string +---@field tokenizer? string +---@field max_input_tokens? integer +---@field max_output_tokens? integer +---@field policy? boolean +--- +---@alias AvanteProviderModelList AvanteProviderModel[] +--- ---@class AvanteProvider: AvanteSupportedProvider ---@field parse_curl_args? AvanteCurlArgsParser ---@field parse_stream_data? AvanteStreamParser @@ -315,6 +328,7 @@ vim.g.avante_login = vim.g.avante_login ---@field on_error? fun(result: table): nil ---@field transform_tool? fun(self: AvanteProviderFunctor, tool: AvanteLLMTool): AvanteOpenAITool | AvanteClaudeTool ---@field get_rate_limit_sleep_time? fun(self: AvanteProviderFunctor, headers: table): integer | nil +---@field models_list? fun(self): AvanteProviderModelList | nil --- ---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, request_body: table): table ---