From b28d8998f7ec726a3b905c5de1df60775aaa6911 Mon Sep 17 00:00:00 2001 From: TomiEckert Date: Wed, 18 Mar 2026 22:16:28 +0100 Subject: [PATCH] feat: parallel async processing and compact output mode Major performance improvements: - Parallel search execution across all queries - Parallel article fetching with 10 concurrent limit - Parallel embeddings with rate limiting (4 concurrent) - Polly integration for retry resilience New features: - Add -v/--verbose flag for detailed output - Compact single-line status mode with braille spinner - StatusReporter service for unified output handling - Query generation and errors hidden in compact mode - ANSI escape codes for clean line updates New files: - Services/RateLimiter.cs - Semaphore-based concurrency control - Services/StatusReporter.cs - Verbose/compact output handler - Models/ParallelOptions.cs - Parallel processing configuration All changes maintain Native AOT compatibility. --- Models/OpenQueryOptions.cs | 3 +- Models/ParallelOptions.cs | 8 ++ OpenQuery.cs | 119 ++++++++++++------- OpenQuery.csproj | 4 +- Program.cs | 12 +- Services/EmbeddingService.cs | 158 +++++++++++++++++++++++--- Services/RateLimiter.cs | 42 +++++++ Services/StatusReporter.cs | 128 +++++++++++++++++++++ Tools/SearchTool.cs | 214 ++++++++++++++++++++++++++++------- 9 files changed, 579 insertions(+), 109 deletions(-) create mode 100644 Models/ParallelOptions.cs create mode 100644 Services/RateLimiter.cs create mode 100644 Services/StatusReporter.cs diff --git a/Models/OpenQueryOptions.cs b/Models/OpenQueryOptions.cs index 5195b03..dff1a1a 100644 --- a/Models/OpenQueryOptions.cs +++ b/Models/OpenQueryOptions.cs @@ -6,5 +6,6 @@ public record OpenQueryOptions( int Queries, bool Short, bool Long, + bool Verbose, string Question -); \ No newline at end of file +); diff --git a/Models/ParallelOptions.cs b/Models/ParallelOptions.cs new file mode 100644 index 0000000..0dab9d6 --- /dev/null +++ b/Models/ParallelOptions.cs @@ -0,0 +1,8 @@ +namespace 
OpenQuery.Models; + +public class ParallelProcessingOptions +{ + public int MaxConcurrentArticleFetches { get; set; } = 10; + public int MaxConcurrentEmbeddingRequests { get; set; } = 4; + public int EmbeddingBatchSize { get; set; } = 300; +} diff --git a/OpenQuery.cs b/OpenQuery.cs index 93404a9..89ce93f 100644 --- a/OpenQuery.cs +++ b/OpenQuery.cs @@ -12,7 +12,6 @@ public class OpenQueryApp private readonly OpenRouterClient _client; private readonly SearchTool _searchTool; private readonly string _model; - private static readonly char[] Function = ['|', '/', '-', '\\']; public OpenQueryApp( OpenRouterClient client, @@ -26,12 +25,22 @@ public class OpenQueryApp public async Task RunAsync(OpenQueryOptions options) { + using var reporter = new StatusReporter(options.Verbose); + reporter.StartSpinner(); + var queries = new List { options.Question }; if (options.Queries > 1) { - Console.WriteLine($"[Generating {options.Queries} search queries based on your question...]"); - + if (options.Verbose) + { + reporter.WriteLine($"[Generating {options.Queries} search queries based on your question...]"); + } + else + { + reporter.UpdateStatus("Generating search queries..."); + } + var queryGenMessages = new List { new Message("system", """ @@ -63,23 +72,67 @@ public class OpenQueryApp if (!string.IsNullOrEmpty(content)) { content = Regex.Replace(content, @"```json\s*|\s*```", "").Trim(); - + var generatedQueries = JsonSerializer.Deserialize(content, AppJsonContext.Default.ListString); if (generatedQueries != null && generatedQueries.Count > 0) { queries = generatedQueries; - Console.WriteLine($"[Generated queries: {string.Join(", ", queries)}]"); + if (options.Verbose) + { + reporter.WriteLine($"[Generated queries: {string.Join(", ", queries)}]"); + } } } } catch (Exception ex) { - Console.WriteLine($"[Failed to generate queries, falling back to original question. 
Error: {ex.Message}]"); + if (options.Verbose) + { + reporter.WriteLine($"[Failed to generate queries, falling back to original question. Error: {ex.Message}]"); + } } } - var searchResult = await _searchTool.ExecuteAsync(options.Question, queries, options.Results, options.Chunks, msg => Console.WriteLine(msg)); - Console.WriteLine(); + reporter.UpdateStatus("Searching web..."); + var searchResult = await _searchTool.ExecuteAsync( + options.Question, + queries, + options.Results, + options.Chunks, + (progress) => { + if (options.Verbose) + { + reporter.WriteLine(progress); + } + else + { + // Parse progress messages for compact mode + if (progress.StartsWith("[Fetching article", StringComparison.Ordinal) && progress.Contains("/")) + { + // Extract "X/Y" from "[Fetching article X/Y: domain]"; the digits follow "article ", not the bracket + var match = Regex.Match(progress, @"article (\d+)/(\d+)"); + if (match.Success) + { + reporter.UpdateStatus($"Fetching articles {match.Groups[1].Value}/{match.Groups[2].Value}..."); + } + } + else if (progress.Contains("embeddings")) + { + reporter.UpdateStatus("Processing embeddings..."); + } + } + }, + options.Verbose); + + if (!options.Verbose) + { + reporter.UpdateStatus("Asking AI..."); + } + else + { + reporter.ClearStatus(); + Console.WriteLine(); + } var systemPrompt = "You are a helpful AI assistant. Answer the user's question in depth, based on the provided context. Be precise and accurate. You can mention sources or citations."; if (options.Short) @@ -94,46 +147,27 @@ public class OpenQueryApp }; var requestStream = new ChatCompletionRequest(_model, messages); - + var assistantResponse = new StringBuilder(); var isFirstChunk = true; - Console.Write("[Sending request to AI model...] 
"); - - using var cts = new CancellationTokenSource(); - var spinnerTask = Task.Run(async () => - { - var spinner = Function; - var index = 0; - while (cts is { Token.IsCancellationRequested: false }) - { - if (Console.CursorLeft > 0) - { - Console.Write(spinner[index++ % spinner.Length]); - Console.SetCursorPosition(Console.CursorLeft - 1, Console.CursorTop); - } - try - { - await Task.Delay(100, cts.Token); - } - catch (TaskCanceledException) - { - break; - } - } - }, cts.Token); - try { - await foreach (var chunk in _client.StreamAsync(requestStream, cts.Token)) + using var streamCts = new CancellationTokenSource(); + await foreach (var chunk in _client.StreamAsync(requestStream, streamCts.Token)) { if (chunk.TextDelta == null) continue; if (isFirstChunk) { - await cts.CancelAsync(); - await spinnerTask; - Console.WriteLine(); - Console.Write("Assistant: "); + reporter.StopSpinner(); + if (!options.Verbose) + { + reporter.ClearStatus(); + } + else + { + Console.Write("Assistant: "); + } isFirstChunk = false; } Console.Write(chunk.TextDelta); @@ -142,12 +176,9 @@ public class OpenQueryApp } finally { - if (!cts.IsCancellationRequested) - { - await cts.CancelAsync(); - } + reporter.StopSpinner(); } Console.WriteLine(); } -} \ No newline at end of file +} diff --git a/OpenQuery.csproj b/OpenQuery.csproj index 2f2406f..be2f648 100644 --- a/OpenQuery.csproj +++ b/OpenQuery.csproj @@ -1,4 +1,4 @@ - + Exe @@ -10,6 +10,8 @@ + + diff --git a/Program.cs b/Program.cs index 54b93e9..0273590 100644 --- a/Program.cs +++ b/Program.cs @@ -34,6 +34,11 @@ var longOption = new Option( description: "Give a long elaborate detailed answer" ); +var verboseOption = new Option( + aliases: ["-v", "--verbose"], + description: "Show detailed progress information" +); + var questionArgument = new Argument( name: "question", description: "The question to ask" @@ -127,11 +132,12 @@ var rootCommand = new RootCommand("OpenQuery - AI powered search and answer") queriesOption, shortOption, 
longOption, + verboseOption, questionArgument, configureCommand }; -rootCommand.SetHandler(async (chunks, results, queries, isShort, isLong, questionArgs) => +rootCommand.SetHandler(async (chunks, results, queries, isShort, isLong, verbose, questionArgs) => { var question = string.Join(" ", questionArgs); if (string.IsNullOrWhiteSpace(question)) @@ -140,7 +146,7 @@ rootCommand.SetHandler(async (chunks, results, queries, isShort, isLong, questio return; } - var options = new OpenQueryOptions(chunks, results, queries, isShort, isLong, question); + var options = new OpenQueryOptions(chunks, results, queries, isShort, isLong, verbose, question); var apiKey = Environment.GetEnvironmentVariable("OPENROUTER_API_KEY"); @@ -183,6 +189,6 @@ rootCommand.SetHandler(async (chunks, results, queries, isShort, isLong, questio Console.Error.WriteLine($"\n[Error] An unexpected error occurred: {ex.Message}"); Environment.Exit(1); } -}, chunksOption, resultsOption, queriesOption, shortOption, longOption, questionArgument); +}, chunksOption, resultsOption, queriesOption, shortOption, longOption, verboseOption, questionArgument); return await rootCommand.InvokeAsync(args); \ No newline at end of file diff --git a/Services/EmbeddingService.cs b/Services/EmbeddingService.cs index e88b1d9..626c21b 100644 --- a/Services/EmbeddingService.cs +++ b/Services/EmbeddingService.cs @@ -1,4 +1,7 @@ using System.Numerics.Tensors; +using OpenQuery.Models; +using Polly; +using Polly.Retry; namespace OpenQuery.Services; @@ -6,33 +9,156 @@ public class EmbeddingService { private readonly OpenRouterClient _client; private readonly string _embeddingModel; + private readonly ParallelProcessingOptions _options; + private readonly RateLimiter _rateLimiter; + private readonly ResiliencePipeline _retryPipeline; public EmbeddingService(OpenRouterClient client, string embeddingModel = "openai/text-embedding-3-small") { _client = client; _embeddingModel = embeddingModel; + _options = new 
ParallelProcessingOptions(); + _rateLimiter = new RateLimiter(_options.MaxConcurrentEmbeddingRequests); + + _retryPipeline = new ResiliencePipelineBuilder() + .AddRetry(new RetryStrategyOptions + { + MaxRetryAttempts = 3, + Delay = TimeSpan.FromSeconds(1), + BackoffType = DelayBackoffType.Exponential, + ShouldHandle = new PredicateBuilder() + .Handle() + }) + .Build(); } - public async Task GetEmbeddingsAsync(List texts) + public async Task GetEmbeddingsAsync( + List texts, + Action? onProgress = null, + CancellationToken cancellationToken = default) { - var results = new List(); - const int batchSize = 300; - - for (var i = 0; i < texts.Count; i += batchSize) - { - if (texts.Count > batchSize) - Console.WriteLine( - $"[Generating {Math.Ceiling(i / (double)batchSize)}/{Math.Ceiling(texts.Count / (double)batchSize)} batch of embeddings]"); - var batch = texts.Skip(i).Take(batchSize).ToList(); - var batchResults = await _client.EmbedAsync(_embeddingModel, batch); - results.AddRange(batchResults); - } - - return results.ToArray(); + var batchSize = _options.EmbeddingBatchSize; + var totalBatches = (int)Math.Ceiling(texts.Count / (double)batchSize); + var results = new List<(int batchIndex, float[][] embeddings)>(); + + var batchIndices = Enumerable.Range(0, totalBatches).ToList(); + + await Parallel.ForEachAsync( + batchIndices, + new ParallelOptions + { + MaxDegreeOfParallelism = _options.MaxConcurrentEmbeddingRequests, + CancellationToken = cancellationToken + }, + async (batchIndex, ct) => + { + var startIndex = batchIndex * batchSize; + var batch = texts.Skip(startIndex).Take(batchSize).ToList(); + + onProgress?.Invoke($"[Generating embeddings: batch {batchIndex + 1}/{totalBatches}]"); + + try + { + var batchResults = await _rateLimiter.ExecuteAsync(async () => + await _retryPipeline.ExecuteAsync(async token => + await _client.EmbedAsync(_embeddingModel, batch), + ct), + ct); + + lock (results) + { + results.Add((batchIndex, batchResults)); + } + } + catch + { + 
// Skip failed batches, return empty embeddings for this batch + var emptyBatch = new float[batch.Count][]; + for (var i = 0; i < batch.Count; i++) + { + emptyBatch[i] = []; + } + lock (results) + { + results.Add((batchIndex, emptyBatch)); + } + } + }); + + // Reassemble results in order + var orderedResults = results + .OrderBy(r => r.batchIndex) + .SelectMany(r => r.embeddings) + .ToArray(); + + return orderedResults; + } + + public async Task GetEmbeddingAsync( + string text, + CancellationToken cancellationToken = default) + { + var results = await _rateLimiter.ExecuteAsync(async () => + await _retryPipeline.ExecuteAsync(async token => + await _client.EmbedAsync(_embeddingModel, [text]), + cancellationToken), + cancellationToken); + + return results[0]; + } + + public async Task GetEmbeddingsWithRateLimitAsync( + List texts, + Action? onProgress = null, + CancellationToken cancellationToken = default) + { + var batchSize = _options.EmbeddingBatchSize; + var totalBatches = (int)Math.Ceiling(texts.Count / (double)batchSize); + var results = new float[totalBatches][][]; + + var completedBatches = 0; + + await Parallel.ForEachAsync( + Enumerable.Range(0, totalBatches), + new ParallelOptions + { + MaxDegreeOfParallelism = _options.MaxConcurrentEmbeddingRequests, + CancellationToken = cancellationToken + }, + async (batchIndex, ct) => + { + var startIndex = batchIndex * batchSize; + var batch = texts.Skip(startIndex).Take(batchSize).ToList(); + + var currentBatch = Interlocked.Increment(ref completedBatches); + onProgress?.Invoke($"[Generating embeddings: batch {currentBatch}/{totalBatches}]"); + + try + { + var batchResults = await _rateLimiter.ExecuteAsync(async () => + await _retryPipeline.ExecuteAsync(async token => + await _client.EmbedAsync(_embeddingModel, batch), + ct), + ct); + + results[batchIndex] = batchResults; + } + catch + { + // Skip failed batches + results[batchIndex] = new float[batch.Count][]; + for (var i = 0; i < batch.Count; i++) + { + 
results[batchIndex][i] = []; + } + } + }); + + return results.SelectMany(r => r).ToArray(); } public static float CosineSimilarity(float[] vector1, float[] vector2) { return TensorPrimitives.CosineSimilarity(vector1, vector2); } -} \ No newline at end of file +} diff --git a/Services/RateLimiter.cs b/Services/RateLimiter.cs new file mode 100644 index 0000000..9df36b3 --- /dev/null +++ b/Services/RateLimiter.cs @@ -0,0 +1,42 @@ +namespace OpenQuery.Services; + +public sealed class RateLimiter : IAsyncDisposable +{ + private readonly SemaphoreSlim _semaphore; + + public RateLimiter(int maxConcurrentRequests) + { + _semaphore = new SemaphoreSlim(maxConcurrentRequests, maxConcurrentRequests); + } + + public async Task ExecuteAsync(Func> action, CancellationToken cancellationToken = default) + { + await _semaphore.WaitAsync(cancellationToken); + try + { + return await action(); + } + finally + { + _semaphore.Release(); + } + } + + public async Task ExecuteAsync(Func action, CancellationToken cancellationToken = default) + { + await _semaphore.WaitAsync(cancellationToken); + try + { + await action(); + } + finally + { + _semaphore.Release(); + } + } + + public async ValueTask DisposeAsync() + { + _semaphore.Dispose(); + } +} diff --git a/Services/StatusReporter.cs b/Services/StatusReporter.cs new file mode 100644 index 0000000..00f17b9 --- /dev/null +++ b/Services/StatusReporter.cs @@ -0,0 +1,128 @@ +using System.Threading.Channels; + +namespace OpenQuery.Services; + +public class StatusReporter : IDisposable +{ + private readonly bool _verbose; + private readonly char[] _spinnerChars = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']; + private string? _currentMessage; + private CancellationTokenSource? _spinnerCts; + private Task? 
_spinnerTask; + private readonly Channel _statusChannel; + private readonly Task _statusProcessor; + + public StatusReporter(bool verbose) + { + _verbose = verbose; + _statusChannel = Channel.CreateUnbounded(); + _statusProcessor = ProcessStatusUpdatesAsync(); + } + + private async Task ProcessStatusUpdatesAsync() + { + await foreach (var message in _statusChannel.Reader.ReadAllAsync()) + { + if (_verbose) + { + Console.WriteLine(message); + continue; + } + + // Clear current line using ANSI escape code + Console.Write("\r\x1b[K"); + + // Write new status with spinner (use first spinner char for static updates) + Console.Write($"{_spinnerChars[0]} {message}"); + + _currentMessage = message; + } + } + + public void UpdateStatus(string message) + { + _statusChannel.Writer.TryWrite(message); + } + + public void ClearStatus() + { + if (_verbose) return; + + Console.Write("\r\x1b[K"); + _currentMessage = null; + } + + public void WriteFinal(string text) + { + if (_verbose) + { + Console.WriteLine(text); + return; + } + + StopSpinner(); + Console.Write("\r\x1b[K"); + Console.Write(text); + Console.WriteLine(); + } + + public void StartSpinner() + { + if (_verbose || _spinnerCts != null) return; + + _spinnerCts = new CancellationTokenSource(); + _spinnerTask = Task.Run(async () => + { + var spinner = _spinnerChars; + var index = 0; + while (_spinnerCts is { Token.IsCancellationRequested: false }) + { + if (_currentMessage != null) + { + Console.Write("\r\x1b[K"); + var charIndex = index++ % spinner.Length; + Console.Write($"{spinner[charIndex]} {_currentMessage}"); + } + + try + { + await Task.Delay(100, _spinnerCts.Token); + } + catch (TaskCanceledException) + { + break; + } + } + }, _spinnerCts.Token); + } + + public void StopSpinner() + { + if (_spinnerCts == null) return; + + _spinnerCts.Cancel(); + _spinnerTask?.GetAwaiter().GetResult(); + _spinnerCts = null; + _spinnerTask = null; + } + + public void WriteLine(string text) + { + if (_verbose) + { + 
Console.WriteLine(text); + return; + } + + StopSpinner(); + ClearStatus(); + Console.WriteLine(text); + } + + public void Dispose() + { + _statusChannel.Writer.Complete(); + _statusProcessor.GetAwaiter().GetResult(); + StopSpinner(); + } +} diff --git a/Tools/SearchTool.cs b/Tools/SearchTool.cs index 6d81803..a36efdd 100644 --- a/Tools/SearchTool.cs +++ b/Tools/SearchTool.cs @@ -1,3 +1,4 @@ +using System.Collections.Concurrent; using OpenQuery.Models; using OpenQuery.Services; @@ -7,6 +8,7 @@ public class SearchTool { private readonly SearxngClient _searxngClient; private readonly EmbeddingService _embeddingService; + private readonly ParallelProcessingOptions _options; public static string Name => "search"; public static string Description => "Search the web for information on a topic"; @@ -17,73 +19,197 @@ public class SearchTool { _searxngClient = searxngClient; _embeddingService = embeddingService; + _options = new ParallelProcessingOptions(); } - public async Task ExecuteAsync(string originalQuery, List generatedQueries, int maxResults, int topChunksLimit, Action? onProgress = null) + public async Task ExecuteAsync( + string originalQuery, + List generatedQueries, + int maxResults, + int topChunksLimit, + Action? onProgress = null, + bool verbose = true) { - var allResults = new List(); - - foreach (var query in generatedQueries) - { - onProgress?.Invoke($"[Searching web for '{query}'...]"); - var results = await _searxngClient.SearchAsync(query, maxResults); - allResults.AddRange(results); - } + // Phase 1: Parallel Searches + var searchResults = await ExecuteParallelSearchesAsync(generatedQueries, maxResults, onProgress, verbose); - var uniqueResults = allResults.DistinctBy(r => r.Url).ToList(); - - if (uniqueResults.Count == 0) + if (searchResults.Count == 0) return "No search results found."; - onProgress?.Invoke($"[Found {uniqueResults.Count} unique results across all queries. 
Fetching and reading articles...]"); - var chunks = new List(); - - foreach (var result in uniqueResults) - { - try - { - var article = await ArticleService.FetchArticleAsync(result.Url); - if (!article.IsReadable || string.IsNullOrEmpty(article.TextContent)) continue; - var textChunks = ChunkingService.ChunkText(article.TextContent); - - chunks.AddRange(textChunks.Select(chunkText => new Chunk(chunkText, result.Url, article.Title))); - } - catch - { - // ignored - } - } + // Phase 2: Parallel Article Fetching + var chunks = await ExecuteParallelArticleFetchingAsync(searchResults, onProgress, verbose); if (chunks.Count == 0) return "Found search results but could not extract readable content."; - onProgress?.Invoke($"[Extracted {chunks.Count} text chunks. Generating embeddings for semantic search...]"); + // Phase 3: Parallel Embeddings with Rate Limiting + var (queryEmbedding, chunkEmbeddings) = await ExecuteParallelEmbeddingsAsync( + originalQuery, chunks, onProgress, verbose); + + // Phase 4: Ranking + var topChunks = RankAndSelectTopChunks(chunks, chunkEmbeddings, queryEmbedding, topChunksLimit); + + onProgress?.Invoke($"[Found top {topChunks.Count} most relevant chunks overall. Generating answer...]"); + + var context = string.Join("\n\n", topChunks.Select((c, i) => + $"[Source {i + 1}: {c.Title ?? "Unknown"}]({c.SourceUrl})\n{c.Content}")); + + return context; + } + + private async Task> ExecuteParallelSearchesAsync( + List generatedQueries, + int maxResults, + Action? 
onProgress, + bool verbose) + { + var allResults = new ConcurrentBag(); + + var searchTasks = generatedQueries.Select(async query => + { + onProgress?.Invoke($"[Searching web for '{query}'...]"); + try + { + var results = await _searxngClient.SearchAsync(query, maxResults); + foreach (var result in results) + { + allResults.Add(result); + } + } + catch (Exception ex) + { + if (verbose) + { + Console.WriteLine($"Warning: Search failed for query '{query}': {ex.Message}"); + } + } + }); + + await Task.WhenAll(searchTasks); + + var uniqueResults = allResults.DistinctBy(r => r.Url).ToList(); + return uniqueResults; + } + + private async Task> ExecuteParallelArticleFetchingAsync( + List searchResults, + Action? onProgress, + bool verbose) + { + var chunks = new ConcurrentBag(); + var completedFetches = 0; + var totalFetches = searchResults.Count; + + using var semaphore = new SemaphoreSlim(_options.MaxConcurrentArticleFetches); + + var fetchTasks = searchResults.Select(async result => + { + await semaphore.WaitAsync(); + try + { + var current = Interlocked.Increment(ref completedFetches); + var domain = Uri.TryCreate(result.Url, UriKind.Absolute, out var uri) + ? uri.Host : result.Url; // a malformed URL must not fault the whole batch + onProgress?.Invoke($"[Fetching article {current}/{totalFetches}: {domain}]"); + + try + { + var article = await ArticleService.FetchArticleAsync(result.Url); + if (!article.IsReadable || string.IsNullOrEmpty(article.TextContent)) + return; + + var textChunks = ChunkingService.ChunkText(article.TextContent); + + foreach (var chunkText in textChunks) + { + chunks.Add(new Chunk(chunkText, result.Url, article.Title)); + } + } + catch (Exception ex) + { + if (verbose) + { + Console.WriteLine($"Warning: Failed to fetch article {result.Url}: {ex.Message}"); + } + } + } + finally + { + semaphore.Release(); + } + }); + + await Task.WhenAll(fetchTasks); + + return chunks.ToList(); + } + + private async Task<(float[] queryEmbedding, float[][] chunkEmbeddings)> ExecuteParallelEmbeddingsAsync( + string originalQuery, + List chunks, + Action? 
onProgress, + bool verbose) + { + onProgress?.Invoke($"[Generating embeddings for {chunks.Count} chunks and query...]"); + + // Start query embedding and chunk embeddings concurrently + var queryEmbeddingTask = _embeddingService.GetEmbeddingAsync(originalQuery); + var chunkTexts = chunks.Select(c => c.Content).ToList(); - var embeddings = await _embeddingService.GetEmbeddingsAsync(chunkTexts); - + var chunkEmbeddingsTask = _embeddingService.GetEmbeddingsWithRateLimitAsync( + chunkTexts, onProgress); + + await Task.WhenAll(queryEmbeddingTask, chunkEmbeddingsTask); + + var queryEmbedding = await queryEmbeddingTask; + var chunkEmbeddings = await chunkEmbeddingsTask; + + // Filter out any chunks with empty embeddings (failed batches) + var validChunks = new List(); + var validEmbeddings = new List(); + for (var i = 0; i < chunks.Count; i++) { - chunks[i] = chunks[i] with { Embedding = embeddings[i] }; + if (chunkEmbeddings[i].Length > 0) + { + validChunks.Add(chunks[i]); + validEmbeddings.Add(chunkEmbeddings[i]); + } } - var queryEmbedding = (await _embeddingService.GetEmbeddingsAsync([originalQuery]))[0]; - - foreach (var chunk in chunks) + // Update chunks with embeddings + for (var i = 0; i < validChunks.Count; i++) + { + validChunks[i].Embedding = validEmbeddings[i]; + } + + return (queryEmbedding, validEmbeddings.ToArray()); + } + + private List RankAndSelectTopChunks( + List chunks, + float[][] chunkEmbeddings, + float[] queryEmbedding, + int topChunksLimit) + { + // Filter to only chunks that have embeddings + var chunksWithEmbeddings = chunks.Where(c => c.Embedding != null).ToList(); + + foreach (var chunk in chunksWithEmbeddings) { chunk.Score = EmbeddingService.CosineSimilarity(queryEmbedding, chunk.Embedding!); } - var topChunks = chunks.OrderByDescending(c => c.Score).Take(topChunksLimit).ToList(); - - onProgress?.Invoke($"[Found top {topChunks.Count} most relevant chunks overall. 
Generating answer...]"); - var context = string.Join("\n\n", topChunks.Select((c, i) => - $"[Source {i + 1}: {c.Title ?? "Unknown"}]({c.SourceUrl})\n{c.Content}")); + var topChunks = chunksWithEmbeddings + .OrderByDescending(c => c.Score) + .Take(topChunksLimit) + .ToList(); - return context; + return topChunks; } public static string Execute(string argumentsJson) { throw new InvalidOperationException("Use ExecuteAsync instead"); } -} \ No newline at end of file +}