From 70e784a1cc3818873a1ca331f2b52ba4f0f92530 Mon Sep 17 00:00:00 2001 From: TomiEckert Date: Sun, 22 Mar 2026 19:02:37 +0100 Subject: [PATCH] add Fireworks AI provider support --- Hush.Cli/src/Commands/SetupCommand.cs | 104 ++++++--- Hush.Cli/src/Commands/ShowCommand.cs | 4 +- Hush.Daemon/src/Orchestrator.cs | 64 +++--- .../src/Providers/FireworksProvider.cs | 211 ++++++++++++++++++ 4 files changed, 310 insertions(+), 73 deletions(-) create mode 100644 Hush.Providers/src/Providers/FireworksProvider.cs diff --git a/Hush.Cli/src/Commands/SetupCommand.cs b/Hush.Cli/src/Commands/SetupCommand.cs index 2c4f7dc..70b5801 100644 --- a/Hush.Cli/src/Commands/SetupCommand.cs +++ b/Hush.Cli/src/Commands/SetupCommand.cs @@ -32,48 +32,62 @@ public static class SetupCommand AnsiConsole.MarkupLine("[bold blue]Welcome to Hush Setup![/]\n"); AnsiConsole.MarkupLine("This wizard will help you configure Hush.\n"); - var config = new HushConfig(); + var configManager = new ConfigManager(); + var config = configManager.Load(); AnsiConsole.MarkupLine("[bold]Step 1: Whisper Provider[/]"); - config.WhisperProvider = AnsiConsole.Prompt( - new SelectionPrompt() - .Title("Select Whisper provider:") - .AddChoices("groq")); + config.WhisperProvider = PromptSelection("Select Whisper provider:", ["groq", "fireworks"], config.WhisperProvider); - AnsiConsole.MarkupLine("[bold]Step 2: Groq API Key[/]"); - config.GroqApiKey = ReadMaskedInput("Enter your Groq API key:"); + AnsiConsole.MarkupLine("[bold]Step 2: LLM Provider[/]"); + config.LlmProvider = PromptSelection("Select LLM provider:", ["groq", "fireworks"], config.LlmProvider); - AnsiConsole.MarkupLine("[bold]Step 3: LLM Model[/]"); - config.LlmProvider = "groq"; - config.LlmModel = AnsiConsole.Prompt( - new SelectionPrompt() - .Title("Select LLM model:") - .AddChoices("openai/gpt-oss-20b", "llama-3.1-8b-instant", "openai/gpt-oss-120b")); + if (config.WhisperProvider == "groq" || config.LlmProvider == "groq") + { + 
AnsiConsole.MarkupLine("[bold]Step 3: Groq API Key[/]"); + config.GroqApiKey = ReadMaskedInput("Enter your Groq API key:", config.GroqApiKey); + } - AnsiConsole.MarkupLine("[bold]Step 4: Whisper Model[/]"); - config.WhisperModel = AnsiConsole.Prompt( - new SelectionPrompt() - .Title("Select Whisper model:") - .AddChoices("whisper-large-v3", "whisper-large-v3-turbo")); + if (config.WhisperProvider == "fireworks" || config.LlmProvider == "fireworks") + { + AnsiConsole.MarkupLine("[bold]Step 3: Fireworks API Key[/]"); + config.FireworksApiKey = ReadMaskedInput("Enter your Fireworks API key:", config.FireworksApiKey); + } - AnsiConsole.MarkupLine("[bold]Step 5: Audio Backend[/]"); - config.AudioBackend = AnsiConsole.Prompt( - new SelectionPrompt() - .Title("Select audio backend:") - .AddChoices("pipewire", "ffmpeg")); + AnsiConsole.MarkupLine("[bold]Step 4: LLM Model[/]"); + config.LlmModel = config.LlmProvider switch + { + "fireworks" => PromptSelection("Select LLM model:", [ + "accounts/fireworks/models/kimi-k2-instruct-0905", + "accounts/fireworks/models/deepseek-v3p1", + "accounts/fireworks/models/llama-v3p1-70b-instruct" + ], config.LlmModel), + _ => PromptSelection("Select LLM model:", [ + "openai/gpt-oss-20b", "llama-3.1-8b-instant", "openai/gpt-oss-120b" + ], config.LlmModel) + }; - AnsiConsole.MarkupLine("[bold]Step 6: Typing Backend[/]"); - config.TypingBackend = AnsiConsole.Prompt( - new SelectionPrompt() - .Title("Select typing backend:") - .AddChoices("wtype", "xdotool")); + AnsiConsole.MarkupLine("[bold]Step 5: Whisper Model[/]"); + config.WhisperModel = config.WhisperProvider switch + { + "fireworks" => PromptSelection("Select Whisper model:", [ + "whisper-v3", "whisper-v3-turbo" + ], config.WhisperModel), + _ => PromptSelection("Select Whisper model:", [ + "whisper-large-v3", "whisper-large-v3-turbo" + ], config.WhisperModel) + }; - AnsiConsole.MarkupLine("[bold]Step 7: Minimum Recording Duration[/]"); - var minDuration = AnsiConsole.Prompt( - new 
TextPrompt("Enter minimum duration (ms, default 500):") - .DefaultValue(500) + AnsiConsole.MarkupLine("[bold]Step 6: Audio Backend[/]"); + config.AudioBackend = PromptSelection("Select audio backend:", ["pipewire", "ffmpeg"], config.AudioBackend); + + AnsiConsole.MarkupLine("[bold]Step 7: Typing Backend[/]"); + config.TypingBackend = PromptSelection("Select typing backend:", ["wtype", "xdotool"], config.TypingBackend); + + AnsiConsole.MarkupLine("[bold]Step 8: Minimum Recording Duration[/]"); + config.MinRecordingDuration = AnsiConsole.Prompt( + new TextPrompt("Enter minimum duration (ms):") + .DefaultValue(config.MinRecordingDuration) .Validate(x => x > 0, "Must be greater than 0")); - config.MinRecordingDuration = minDuration; AnsiConsole.WriteLine(); AnsiConsole.MarkupLine("[bold]Configuration Summary:[/]"); @@ -82,7 +96,11 @@ public static class SetupCommand table.AddColumn("Setting"); table.AddColumn("Value"); table.AddRow("Whisper Provider", config.WhisperProvider); - table.AddRow("Groq API Key", MaskApiKey(config.GroqApiKey)); + table.AddRow("LLM Provider", config.LlmProvider); + if (!string.IsNullOrEmpty(config.GroqApiKey)) + table.AddRow("Groq API Key", MaskApiKey(config.GroqApiKey)); + if (!string.IsNullOrEmpty(config.FireworksApiKey)) + table.AddRow("Fireworks API Key", MaskApiKey(config.FireworksApiKey)); table.AddRow("LLM Model", config.LlmModel); table.AddRow("Whisper Model", config.WhisperModel); table.AddRow("Audio Backend", config.AudioBackend); @@ -94,9 +112,20 @@ public static class SetupCommand return config; } - private static string ReadMaskedInput(string prompt) + private static string PromptSelection(string title, string[] choices, string current) { - AnsiConsole.Write(prompt + " "); + var prompt = new SelectionPrompt().Title(title); + // Put current value first so it starts highlighted; add it if it's a custom value not in the list + prompt.AddChoice(current); + foreach (var choice in choices.Where(c => c != current)) + 
prompt.AddChoice(choice); + return AnsiConsole.Prompt(prompt); + } + + private static string ReadMaskedInput(string prompt, string current = "") + { + var hint = string.IsNullOrEmpty(current) ? "" : $" [{MaskApiKey(current)}]"; + AnsiConsole.Write(prompt + hint + " "); var input = ""; while (true) { @@ -117,7 +146,8 @@ public static class SetupCommand AnsiConsole.Write("*"); } } - return input; + // Keep existing key if user just hit Enter without typing anything + return string.IsNullOrEmpty(input) ? current : input; } private static string MaskApiKey(string key) diff --git a/Hush.Cli/src/Commands/ShowCommand.cs b/Hush.Cli/src/Commands/ShowCommand.cs index bf8704d..4f355d3 100644 --- a/Hush.Cli/src/Commands/ShowCommand.cs +++ b/Hush.Cli/src/Commands/ShowCommand.cs @@ -36,6 +36,7 @@ public static class ShowCommand { AnsiConsole.WriteLine($"whisper_provider={config.WhisperProvider}"); AnsiConsole.WriteLine($"groq_api_key={MaskApiKey(config.GroqApiKey)}"); + AnsiConsole.WriteLine($"fireworks_api_key={MaskApiKey(config.FireworksApiKey)}"); AnsiConsole.WriteLine($"llm_provider={config.LlmProvider}"); AnsiConsole.WriteLine($"llm_model={config.LlmModel}"); AnsiConsole.WriteLine($"whisper_model={config.WhisperModel}"); @@ -50,8 +51,9 @@ public static class ShowCommand table.AddColumn("Value"); table.AddRow("Whisper Provider", config.WhisperProvider); - table.AddRow("Groq API Key", MaskApiKey(config.GroqApiKey)); table.AddRow("LLM Provider", config.LlmProvider); + table.AddRow("Groq API Key", MaskApiKey(config.GroqApiKey)); + table.AddRow("Fireworks API Key", MaskApiKey(config.FireworksApiKey)); table.AddRow("LLM Model", config.LlmModel); table.AddRow("Whisper Model", config.WhisperModel); table.AddRow("Audio Backend", config.AudioBackend); diff --git a/Hush.Daemon/src/Orchestrator.cs b/Hush.Daemon/src/Orchestrator.cs index 1c312f4..0948b86 100644 --- a/Hush.Daemon/src/Orchestrator.cs +++ b/Hush.Daemon/src/Orchestrator.cs @@ -8,11 +8,10 @@ namespace Hush.Daemon; public 
class Orchestrator { + private static readonly HttpClient _httpClient = new(); + private readonly ConfigManager _configManager; private readonly IAudioRecorder _recorder; - private IAudioToTextProvider? _audioToTextProvider; - private ITextStreamingProvider? _textProvider; - private ITextInput? _textInput; private string? _recordingPath; private DateTime? _recordingStartTime; @@ -149,9 +148,19 @@ public class Orchestrator var provider = GetTextProvider(config); var prompt = $""" - Process this spoken text for clarity and correctness. Fix any errors, add proper punctuation, and make it read naturally. Keep the original meaning intact. + You are a transcription post-processor. Your task is to clean up raw speech-to-text output and return polished, ready-to-type text. - Text: {text} + Rules: + - Detect the language of the transcription and process it entirely in that language — do not translate + - Fix grammar, spelling, and punctuation errors introduced by the speech recognizer, following the conventions of the detected language + - Capitalize sentences and proper nouns appropriately for the detected language + - Remove filler words and false starts appropriate to the detected language (e.g. 
"um", "uh", "like" in English; "euh", "bah" in French; "äh", "ähm" in German; "eh", "tipo" in Spanish/Italian) + - Preserve the speaker's original intent, vocabulary choices, and tone + - Do not add, remove, or reinterpret content beyond what was said + - Do not include any explanation, preamble, or metadata — output only the corrected text + - If the input is empty or unintelligible, return an empty string + + Raw transcription: {text} """; return await provider.CompleteTextAsync(prompt, config.LlmModel); @@ -163,51 +172,36 @@ public class Orchestrator await input.TypeString(text); } - private IAudioToTextProvider GetAudioToTextProvider(HushConfig config) - { - if (_audioToTextProvider != null) - return _audioToTextProvider; - - _audioToTextProvider = config.WhisperProvider switch + private IAudioToTextProvider GetAudioToTextProvider(HushConfig config) => + config.WhisperProvider switch { - "groq" => string.IsNullOrEmpty(config.GroqApiKey) + "groq" => string.IsNullOrEmpty(config.GroqApiKey) ? throw new InvalidOperationException("Groq API key is required for Whisper transcription") - : new GroqProvider(config.GroqApiKey), + : new GroqProvider(config.GroqApiKey, _httpClient), + "fireworks" => string.IsNullOrEmpty(config.FireworksApiKey) + ? throw new InvalidOperationException("Fireworks API key is required for Whisper transcription") + : new FireworksProvider(config.FireworksApiKey, _httpClient), _ => throw new InvalidOperationException($"Unsupported Whisper provider: {config.WhisperProvider}") }; - - return _audioToTextProvider; - } - private ITextStreamingProvider GetTextProvider(HushConfig config) - { - if (_textProvider != null) - return _textProvider; - - _textProvider = config.LlmProvider switch + private ITextStreamingProvider GetTextProvider(HushConfig config) => + config.LlmProvider switch { "groq" => string.IsNullOrEmpty(config.GroqApiKey) ? 
throw new InvalidOperationException("Groq API key is required for LLM") - : new GroqProvider(config.GroqApiKey), + : new GroqProvider(config.GroqApiKey, _httpClient), + "fireworks" => string.IsNullOrEmpty(config.FireworksApiKey) + ? throw new InvalidOperationException("Fireworks API key is required for LLM") + : new FireworksProvider(config.FireworksApiKey, _httpClient), _ => throw new InvalidOperationException($"Unsupported LLM provider: {config.LlmProvider}") }; - - return _textProvider; - } - private ITextInput GetTextInput(HushConfig config) - { - if (_textInput != null) - return _textInput; - - _textInput = config.TypingBackend switch + private static ITextInput GetTextInput(HushConfig config) => + config.TypingBackend switch { "xdotool" => new XdotoolInput(), _ => new WtypeInput() }; - - return _textInput; - } private IAudioRecorder CreateAudioRecorder() {
diff --git a/Hush.Providers/src/Providers/FireworksProvider.cs b/Hush.Providers/src/Providers/FireworksProvider.cs
new file mode 100644
index 0000000..8980c4a
--- /dev/null
+++ b/Hush.Providers/src/Providers/FireworksProvider.cs
@@ -0,0 +1,248 @@
+using System.Net.Http.Headers;
+using System.Text;
+using System.Text.Json;
+using Hush.Providers.Interfaces;
+using Hush.Providers.Models.Request;
+using Hush.Providers.Serialization;
+
+namespace Hush.Providers.Providers;
+
+/// <summary>
+/// Speech-to-text and chat-completion provider for the Fireworks AI HTTP API.
+/// </summary>
+public class FireworksProvider : IAudioToTextProvider, ITextStreamingProvider
+{
+    private const string CHAT_COMPLETION_ENDPOINT = "https://api.fireworks.ai/inference/v1/chat/completions";
+    private const string TRANSCRIPTION_ENDPOINT_PROD = "https://audio-prod.api.fireworks.ai/v1/audio/transcriptions";
+    private const string TRANSCRIPTION_ENDPOINT_TURBO = "https://audio-turbo.api.fireworks.ai/v1/audio/transcriptions";
+    private const string STREAM_DATA_PREFIX = "data: ";
+    private const string STREAM_DONE_MARKER = "[DONE]";
+
+    private readonly HttpClient _httpClient;
+    private readonly string _apiKey;
+
+    /// <summary>
+    /// Initializes a new instance of the FireworksProvider class.
+    /// </summary>
+    /// <param name="apiKey">The Fireworks AI API key</param>
+    /// <param name="httpClient">Optional HttpClient instance (for testing)</param>
+    /// <exception cref="ArgumentNullException"><paramref name="apiKey"/> is null.</exception>
+    public FireworksProvider(string apiKey, HttpClient? httpClient = null)
+    {
+        _apiKey = apiKey ?? throw new ArgumentNullException(nameof(apiKey));
+        _httpClient = httpClient ?? new HttpClient();
+    }
+
+    /// <inheritdoc />
+    /// <remarks>
+    /// Model names containing "turbo" (case-insensitive) are sent to the turbo transcription endpoint and all
+    /// others to the prod endpoint. The audio is uploaded as multipart form data named "audio.wav".
+    /// </remarks>
+    public async Task<string> TranscribeAsync(
+        Stream audioStream,
+        string modelName,
+        CancellationToken cancellationToken = default)
+    {
+        ArgumentNullException.ThrowIfNull(audioStream);
+
+        if (string.IsNullOrWhiteSpace(modelName))
+            throw new ArgumentException("Model name is required", nameof(modelName));
+
+        var endpoint = modelName.Contains("turbo", StringComparison.OrdinalIgnoreCase)
+            ? TRANSCRIPTION_ENDPOINT_TURBO
+            : TRANSCRIPTION_ENDPOINT_PROD;
+
+        var request = new TranscriptionRequest { Model = modelName };
+
+        using var content = new MultipartFormDataContent();
+        content.Add(new StreamContent(audioStream), "file", "audio.wav");
+        content.Add(new StringContent(request.Model), "model");
+        if (request.ResponseFormat != null)
+            content.Add(new StringContent(request.ResponseFormat), "response_format");
+        if (request.Language != null)
+            content.Add(new StringContent(request.Language), "language");
+        if (request.Prompt != null)
+            content.Add(new StringContent(request.Prompt), "prompt");
+        if (request.Temperature.HasValue)
+            content.Add(new StringContent(request.Temperature.Value.ToString(System.Globalization.CultureInfo.InvariantCulture)), "temperature");
+
+        using var httpRequest = CreatePostRequest(endpoint, content);
+
+        using var response = await _httpClient.SendAsync(httpRequest, cancellationToken).ConfigureAwait(false);
+
+        response.EnsureSuccessStatusCode();
+
+        var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
+
+        var result = JsonSerializer.Deserialize(
+            responseContent,
+            JsonSourceGeneration.Default.TranscriptionResponse);
+
+        if (result == null)
+            throw new InvalidOperationException("Failed to deserialize transcription response");
+
+        return result.Text;
+    }
+
+    /// <inheritdoc />
+    /// <remarks>
+    /// Sends the prompt as a single "user" message and returns the message content of the first choice.
+    /// </remarks>
+    public async Task<string> CompleteTextAsync(
+        string prompt,
+        string modelName,
+        CancellationToken cancellationToken = default)
+    {
+        if (string.IsNullOrWhiteSpace(prompt))
+            throw new ArgumentException("Prompt is required", nameof(prompt));
+
+        if (string.IsNullOrWhiteSpace(modelName))
+            throw new ArgumentException("Model name is required", nameof(modelName));
+
+        var request = new ChatCompletionRequest
+        {
+            Model = modelName,
+            Messages = [new() { Role = "user", Content = prompt }]
+        };
+
+        using var httpRequest = CreatePostRequest(CHAT_COMPLETION_ENDPOINT, CreateJsonContent(request));
+
+        using var response = await _httpClient.SendAsync(httpRequest, cancellationToken).ConfigureAwait(false);
+
+        response.EnsureSuccessStatusCode();
+
+        var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
+
+        var result = JsonSerializer.Deserialize(responseContent, JsonSourceGeneration.Default.ChatCompletionResponse);
+
+        if (result == null || result.Choices.Count == 0)
+            throw new InvalidOperationException("Failed to deserialize chat completion response");
+
+        return result.Choices[0].Message.Content;
+    }
+
+    /// <inheritdoc />
+    /// <remarks>
+    /// Reads the response as server-sent events: only lines starting with "data: " are parsed, the "[DONE]"
+    /// payload ends the stream, and chunks that carry no text are skipped.
+    /// </remarks>
+    public async IAsyncEnumerable<string> StreamTextAsync(
+        string prompt,
+        string modelName,
+        [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
+    {
+        if (string.IsNullOrWhiteSpace(prompt))
+            throw new ArgumentException("Prompt is required", nameof(prompt));
+
+        if (string.IsNullOrWhiteSpace(modelName))
+            throw new ArgumentException("Model name is required", nameof(modelName));
+
+        var request = new ChatCompletionRequest
+        {
+            Model = modelName,
+            Stream = true,
+            Messages = [new() { Role = "user", Content = prompt }]
+        };
+
+        using var httpRequest = CreatePostRequest(CHAT_COMPLETION_ENDPOINT, CreateJsonContent(request));
+
+        // ResponseHeadersRead returns as soon as headers arrive so the body can be read incrementally.
+        using var response = await _httpClient.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
+
+        response.EnsureSuccessStatusCode();
+
+        using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
+        using var reader = new StreamReader(stream);
+
+        string? line;
+        while ((line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) != null)
+        {
+            // Ordinal comparison: the prefix is a protocol token, not culture-sensitive text.
+            if (!line.StartsWith(STREAM_DATA_PREFIX, StringComparison.Ordinal))
+                continue;
+
+            var data = line[STREAM_DATA_PREFIX.Length..].Trim();
+
+            if (data == STREAM_DONE_MARKER)
+                break;
+
+            var text = ParseTextFromStreamData(data);
+            if (!string.IsNullOrEmpty(text))
+                yield return text;
+        }
+    }
+
+    /// <summary>
+    /// Serializes a chat completion request as a UTF-8 "application/json" body.
+    /// </summary>
+    private static StringContent CreateJsonContent(ChatCompletionRequest request) =>
+        new(
+            JsonSerializer.Serialize(request, JsonSourceGeneration.Default.ChatCompletionRequest),
+            Encoding.UTF8,
+            "application/json");
+
+    /// <summary>
+    /// Creates a POST request for <paramref name="endpoint"/> that carries <paramref name="content"/> and the API key.
+    /// </summary>
+    private HttpRequestMessage CreatePostRequest(string endpoint, HttpContent content)
+    {
+        var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpoint)
+        {
+            Content = content
+        };
+
+        // Send the key with the Bearer scheme rather than as a bare header value.
+        httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _apiKey);
+
+        return httpRequest;
+    }
+
+    /// <summary>
+    /// Extracts the text carried by one streamed chunk, or null when the chunk has none.
+    /// </summary>
+    /// <remarks>
+    /// Reads <c>choices[0].delta.content</c>, falling back to <c>choices[0].text</c> when there is no delta.
+    /// Malformed JSON and chunks with a missing or empty <c>choices</c> array yield null instead of throwing.
+    /// </remarks>
+    private static string? ParseTextFromStreamData(string data)
+    {
+        try
+        {
+            using var jsonDoc = JsonDocument.Parse(data);
+            var root = jsonDoc.RootElement;
+
+            // Guard the shape first: the indexer and GetProperty throw non-JsonException errors otherwise.
+            if (root.ValueKind != JsonValueKind.Object
+                || !root.TryGetProperty("choices", out var choices)
+                || choices.ValueKind != JsonValueKind.Array
+                || choices.GetArrayLength() == 0)
+            {
+                return null;
+            }
+
+            var choice = choices[0];
+            if (choice.ValueKind != JsonValueKind.Object)
+                return null;
+
+            if (choice.TryGetProperty("delta", out var delta))
+            {
+                if (delta.ValueKind == JsonValueKind.Object
+                    && delta.TryGetProperty("content", out var content)
+                    && content.ValueKind == JsonValueKind.String)
+                {
+                    return content.GetString();
+                }
+            }
+            else if (choice.TryGetProperty("text", out var text) && text.ValueKind == JsonValueKind.String)
+            {
+                return text.GetString();
+            }
+        }
+        catch (JsonException)
+        {
+            // Skip malformed JSON chunks
+        }
+
+        return null;
+    }
+}