add configuration profiles with per-invocation --profile flag

- Add SystemPrompt field to HushConfig (empty = built-in default)
- Refactor ConfigManager: extract ApplyTomlFields, add LoadWithProfile(),
  ListProfiles(), GetProfilePath(), EnsureProfilesDirExists(); remove
  HUSH_PROFILE env-var logic (profiles are now resolved by the CLI)
- Extend socket protocol: action commands (START/STOP/TOGGLE/ABORT) now
  carry a [4-byte LE length][optional HushConfig JSON] payload so the CLI
  can pass a per-invocation config override without restarting the daemon
- Add GENERATE_PROFILE (cmd 7) socket command: CLI sends a description,
  daemon calls the LLM and returns a generated system prompt
- Orchestrator: StopAndProcessAsync accepts optional HushConfig override;
  ProcessWithLlmAsync uses proper system/user chat roles and respects
  config.SystemPrompt; add GenerateProfilePromptAsync
- Split CompleteTextAsync signature to (systemPrompt, userMessage, model)
  across ITextStreamingProvider, GroqProvider, FireworksProvider
- Add --profile/-p flag to hush toggle and hush stop
- Add hush profiles subcommand: list, get, new (manual or AI-generated), edit
This commit is contained in:
2026-03-23 00:38:29 +01:00
parent 70e784a1cc
commit eb0619dea2
14 changed files with 659 additions and 372 deletions
+1
View File
@@ -30,6 +30,7 @@ public class Program
rootCommand.AddCommand(SetupCommand.Create()); rootCommand.AddCommand(SetupCommand.Create());
rootCommand.AddCommand(LatencyTestCommand.Create()); rootCommand.AddCommand(LatencyTestCommand.Create());
rootCommand.AddCommand(ShowCommand.Create()); rootCommand.AddCommand(ShowCommand.Create());
rootCommand.AddCommand(ProfilesCommand.Create());
return rootCommand; return rootCommand;
} }
+250
View File
@@ -0,0 +1,250 @@
using System.CommandLine;
using System.Text.Json;
using Hush.Config;
using Hush.Daemon;
using Spectre.Console;
namespace Hush.Cli.Commands;
public static class ProfilesCommand
{
private static readonly string ProfileTemplate =
"# Hush profile — only fields listed here override the base config.\n" +
"# All fields are optional. Delete any line you don't want to override.\n" +
"#\n" +
"# Available fields:\n" +
"# whisper_provider = \"groq\" # or \"fireworks\"\n" +
"# llm_provider = \"groq\" # or \"fireworks\"\n" +
"# llm_model = \"openai/gpt-oss-20b\"\n" +
"# whisper_model = \"whisper-large-v3-turbo\"\n" +
"# whisper_language = \"en\" # ISO-639-1, empty = auto-detect\n" +
"# system_prompt = \"\"\"\n" +
"# Your custom instruction for the LLM goes here.\n" +
"# Output only the final result with no explanation.\n" +
"# \"\"\"\n" +
"\n" +
"system_prompt = \"\"\"\n" +
"You are a transcription post-processor. Clean up the raw speech-to-text output\n" +
"and return polished, ready-to-type text. Fix grammar, punctuation, and remove\n" +
"filler words. Output only the corrected text with no explanation.\n" +
"\"\"\"\n";
public static Command Create()
{
var profiles = new Command("profiles", "Manage configuration profiles");
profiles.AddCommand(CreateListCommand());
profiles.AddCommand(CreateGetCommand());
profiles.AddCommand(CreateNewCommand());
profiles.AddCommand(CreateEditCommand());
return profiles;
}
// ── list ─────────────────────────────────────────────────────────────────
private static Command CreateListCommand()
{
var cmd = new Command("list", "List all available profiles");
cmd.SetHandler(() =>
{
var manager = new ConfigManager();
var profiles = manager.ListProfiles().ToList();
if (profiles.Count == 0)
{
AnsiConsole.MarkupLine("[grey]No profiles found. Use 'hush profiles new <name>' to create one.[/]");
return;
}
foreach (var name in profiles)
AnsiConsole.WriteLine(name);
});
return cmd;
}
// ── get ──────────────────────────────────────────────────────────────────
private static Command CreateGetCommand()
{
var nameArg = new Argument<string>("name", "Profile name");
var cmd = new Command("get", "Print the contents of a profile");
cmd.AddArgument(nameArg);
cmd.SetHandler((context) =>
{
var name = context.ParseResult.GetValueForArgument(nameArg);
var manager = new ConfigManager();
var path = manager.GetProfilePath(name);
if (!File.Exists(path))
{
AnsiConsole.MarkupLine($"[red]Profile '{name}' not found.[/]");
context.ExitCode = 1;
return;
}
Console.Write(File.ReadAllText(path));
});
return cmd;
}
// ── new ──────────────────────────────────────────────────────────────────
private static Command CreateNewCommand()
{
var nameArg = new Argument<string>("name", "Profile name");
var cmd = new Command("new", "Create a new profile");
cmd.AddArgument(nameArg);
cmd.SetHandler(async (context) =>
{
var name = context.ParseResult.GetValueForArgument(nameArg);
var manager = new ConfigManager();
var path = manager.GetProfilePath(name);
if (File.Exists(path))
{
AnsiConsole.MarkupLine($"[red]Profile '{name}' already exists. Use 'hush profiles edit {name}' to edit it.[/]");
context.ExitCode = 1;
return;
}
// Ask creation mode
var mode = AnsiConsole.Prompt(
new SelectionPrompt<string>()
.Title("How do you want to create this profile?")
.AddChoices("Generate with AI (describe what you want)", "Create manually (edit a template)"));
string initialContent;
if (mode.StartsWith("Generate"))
{
var description = AnsiConsole.Ask<string>("Describe what you want this profile to do:");
var generatedPrompt = await AnsiConsole.Status()
.Spinner(Spinner.Known.Dots)
.StartAsync("Generating system prompt...", async _ =>
await GenerateSystemPromptAsync(description));
if (generatedPrompt == null)
{
context.ExitCode = 1;
return;
}
initialContent =
$"# Hush profile — AI-generated for: {description}\n" +
"# Review and adjust as needed, then save and close your editor.\n" +
"\n" +
"system_prompt = \"\"\"\n" +
$"{generatedPrompt}\n" +
"\"\"\"\n";
}
else
{
initialContent = ProfileTemplate;
}
manager.EnsureProfilesDirExists();
File.WriteAllText(path, initialContent);
OpenInEditor(path);
AnsiConsole.MarkupLine($"[green]Profile '{name}' saved.[/]");
});
return cmd;
}
// ── edit ─────────────────────────────────────────────────────────────────
private static Command CreateEditCommand()
{
var nameArg = new Argument<string>("name", "Profile name");
var cmd = new Command("edit", "Edit an existing profile in $EDITOR");
cmd.AddArgument(nameArg);
cmd.SetHandler((context) =>
{
var name = context.ParseResult.GetValueForArgument(nameArg);
var manager = new ConfigManager();
var path = manager.GetProfilePath(name);
if (!File.Exists(path))
{
AnsiConsole.MarkupLine($"[red]Profile '{name}' not found. Use 'hush profiles new {name}' to create it.[/]");
context.ExitCode = 1;
return;
}
OpenInEditor(path);
});
return cmd;
}
// ── helpers ───────────────────────────────────────────────────────────────
private static async Task<string?> GenerateSystemPromptAsync(string description)
{
try
{
await using var client = new SocketClient();
await client.ConnectAsync(TimeSpan.FromSeconds(2));
var request = new GenerateProfileRequest(description);
await client.SendRequestAsync(
DaemonProtocol.GENERATE_PROFILE,
request,
DaemonJsonContext.Default.GenerateProfileRequest);
var json = await client.ReceiveRawJsonAsync(TimeSpan.FromSeconds(30));
if (json == null)
{
AnsiConsole.MarkupLine("[red]No response from daemon.[/]");
return null;
}
// Check for error response first
var error = JsonSerializer.Deserialize(json, DaemonJsonContext.Default.ErrorResponse);
if (error?.Error != null)
{
AnsiConsole.MarkupLine($"[red]Daemon error: {error.Error}[/]");
return null;
}
var result = JsonSerializer.Deserialize(json, DaemonJsonContext.Default.GenerateProfileResponse);
return result?.SystemPrompt;
}
catch (Exception ex)
{
AnsiConsole.MarkupLine($"[red]Error: {ex.Message}[/]");
return null;
}
}
private static void OpenInEditor(string path)
{
var editor = Environment.GetEnvironmentVariable("EDITOR");
if (string.IsNullOrEmpty(editor))
editor = "nano";
try
{
var process = new System.Diagnostics.Process
{
StartInfo = new System.Diagnostics.ProcessStartInfo
{
FileName = editor,
Arguments = $"\"{path}\"",
UseShellExecute = false
}
};
process.Start();
process.WaitForExit();
}
catch (Exception ex)
{
AnsiConsole.MarkupLine($"[red]Could not open editor '{editor}': {ex.Message}[/]");
AnsiConsole.MarkupLine($"Profile saved at: {path}");
}
}
}
+16
View File
@@ -1,4 +1,5 @@
using System.CommandLine; using System.CommandLine;
using Hush.Config;
using Hush.Daemon; using Hush.Daemon;
using Spectre.Console; using Spectre.Console;
@@ -9,13 +10,28 @@ public static class StopCommand
public static Command Create() public static Command Create()
{ {
var command = new Command("stop", "Stop recording and process"); var command = new Command("stop", "Stop recording and process");
var profileOption = new Option<string?>(["--profile", "-p"], "Profile name to apply when processing");
command.AddOption(profileOption);
command.SetHandler(async (context) => command.SetHandler(async (context) =>
{ {
var profileName = context.ParseResult.GetValueForOption(profileOption);
try try
{ {
await using var client = new SocketClient(); await using var client = new SocketClient();
await client.ConnectAsync(TimeSpan.FromSeconds(2)); await client.ConnectAsync(TimeSpan.FromSeconds(2));
if (!string.IsNullOrEmpty(profileName))
{
var config = new ConfigManager().LoadWithProfile(profileName);
await client.SendCommandWithConfigAsync(DaemonProtocol.STOP, config);
}
else
{
await client.SendCommandAsync(DaemonProtocol.STOP); await client.SendCommandAsync(DaemonProtocol.STOP);
}
AnsiConsole.MarkupLine("[green]Stop command sent[/]"); AnsiConsole.MarkupLine("[green]Stop command sent[/]");
} }
catch (Exception ex) catch (Exception ex)
+16
View File
@@ -1,4 +1,5 @@
using System.CommandLine; using System.CommandLine;
using Hush.Config;
using Hush.Daemon; using Hush.Daemon;
using Spectre.Console; using Spectre.Console;
@@ -9,13 +10,28 @@ public static class ToggleCommand
public static Command Create() public static Command Create()
{ {
var command = new Command("toggle", "Toggle recording (start if idle, stop if recording)"); var command = new Command("toggle", "Toggle recording (start if idle, stop if recording)");
var profileOption = new Option<string?>(["--profile", "-p"], "Profile name to apply when processing stops");
command.AddOption(profileOption);
command.SetHandler(async (context) => command.SetHandler(async (context) =>
{ {
var profileName = context.ParseResult.GetValueForOption(profileOption);
try try
{ {
await using var client = new SocketClient(); await using var client = new SocketClient();
await client.ConnectAsync(TimeSpan.FromSeconds(2)); await client.ConnectAsync(TimeSpan.FromSeconds(2));
if (!string.IsNullOrEmpty(profileName))
{
var config = new ConfigManager().LoadWithProfile(profileName);
await client.SendCommandWithConfigAsync(DaemonProtocol.TOGGLE, config);
}
else
{
await client.SendCommandAsync(DaemonProtocol.TOGGLE); await client.SendCommandAsync(DaemonProtocol.TOGGLE);
}
AnsiConsole.MarkupLine("[green]Toggle command sent[/]"); AnsiConsole.MarkupLine("[green]Toggle command sent[/]");
} }
catch (Exception ex) catch (Exception ex)
+55
View File
@@ -1,6 +1,8 @@
using System.Net.Sockets; using System.Net.Sockets;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using Hush.Config;
using Hush.Daemon; using Hush.Daemon;
namespace Hush.Cli; namespace Hush.Cli;
@@ -26,10 +28,59 @@ public class SocketClient : IAsyncDisposable
await _socket.ConnectAsync(_endPoint, cts.Token); await _socket.ConnectAsync(_endPoint, cts.Token);
} }
/// <summary>
/// Sends a command with no config payload.
/// Action commands (START/STOP/TOGGLE/ABORT) always include a 4-byte zero length prefix
/// so the daemon can read the same framing unconditionally.
/// </summary>
public async Task SendCommandAsync(byte command) public async Task SendCommandAsync(byte command)
{
if (IsActionCommand(command))
{
// [cmd][4 zero bytes] — signals no config override
var frame = new byte[5];
frame[0] = command;
await _socket.SendAsync(frame, SocketFlags.None);
}
else
{ {
await _socket.SendAsync(new[] { command }, SocketFlags.None); await _socket.SendAsync(new[] { command }, SocketFlags.None);
} }
}
/// <summary>
/// Sends an action command with a HushConfig override payload.
/// Format: [1 byte cmd][4-byte LE length][N bytes JSON]
/// </summary>
public async Task SendCommandWithConfigAsync(byte command, HushConfig config)
{
var jsonBytes = JsonSerializer.SerializeToUtf8Bytes(config, HushConfigContext.Default.HushConfig);
var lenBytes = BitConverter.GetBytes(jsonBytes.Length);
var frame = new byte[1 + 4 + jsonBytes.Length];
frame[0] = command;
lenBytes.CopyTo(frame, 1);
jsonBytes.CopyTo(frame, 5);
await _socket.SendAsync(frame, SocketFlags.None);
}
/// <summary>
/// Sends a request with a typed JSON payload (e.g. GENERATE_PROFILE).
/// Format: [1 byte cmd][4-byte LE length][N bytes JSON]
/// </summary>
public async Task SendRequestAsync<TRequest>(byte command, TRequest payload, JsonTypeInfo<TRequest> typeInfo)
{
var jsonBytes = JsonSerializer.SerializeToUtf8Bytes(payload, typeInfo);
var lenBytes = BitConverter.GetBytes(jsonBytes.Length);
var frame = new byte[1 + 4 + jsonBytes.Length];
frame[0] = command;
lenBytes.CopyTo(frame, 1);
jsonBytes.CopyTo(frame, 5);
await _socket.SendAsync(frame, SocketFlags.None);
}
public async Task<T?> ReceiveJsonAsync<T>(TimeSpan timeout) public async Task<T?> ReceiveJsonAsync<T>(TimeSpan timeout)
{ {
@@ -61,4 +112,8 @@ public class SocketClient : IAsyncDisposable
_socket.Dispose(); _socket.Dispose();
await ValueTask.CompletedTask; await ValueTask.CompletedTask;
} }
private static bool IsActionCommand(byte command) =>
command is DaemonProtocol.START or DaemonProtocol.STOP
or DaemonProtocol.ABORT or DaemonProtocol.TOGGLE;
} }
+77 -33
View File
@@ -7,53 +7,61 @@ public class ConfigManager
{ {
private readonly string _configDir; private readonly string _configDir;
private readonly string _configPath; private readonly string _configPath;
private readonly string _profilesDir;
public ConfigManager() public ConfigManager()
{ {
var homeDir = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); var homeDir = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile);
_configDir = Path.Combine(homeDir, ".config", "hush"); _configDir = Path.Combine(homeDir, ".config", "hush");
_configPath = Path.Combine(_configDir, "config"); _configPath = Path.Combine(_configDir, "config");
_profilesDir = Path.Combine(_configDir, "profiles");
} }
public HushConfig Load() public HushConfig Load() => LoadFromFile(_configPath);
/// <summary>
/// Loads the base config and merges the named profile on top of it.
/// Only fields present in the profile file override the base config.
/// Silently falls back to the base config on any error.
/// </summary>
public HushConfig LoadWithProfile(string profileName)
{ {
if (!File.Exists(_configPath)) var config = LoadFromFile(_configPath);
{
return new HushConfig(); var profilePath = Path.Combine(_profilesDir, profileName);
} if (!File.Exists(profilePath))
return config;
try try
{ {
var toml = File.ReadAllText(_configPath); var profileToml = File.ReadAllText(profilePath);
var model = Toml.ToModel<TomlTable>(toml); var profileModel = Toml.ToModel<TomlTable>(profileToml);
ApplyTomlFields(profileModel, config);
var config = new HushConfig();
if (model.TryGetValue("groq_api_key", out var groqKey)) config.GroqApiKey = groqKey.ToString() ?? string.Empty;
if (model.TryGetValue("together_api_key", out var togetherKey)) config.TogetherApiKey = togetherKey.ToString() ?? string.Empty;
if (model.TryGetValue("cerebras_api_key", out var cerebrasKey)) config.CerebrasApiKey = cerebrasKey.ToString() ?? string.Empty;
if (model.TryGetValue("fireworks_api_key", out var fireworksKey)) config.FireworksApiKey = fireworksKey.ToString() ?? string.Empty;
if (model.TryGetValue("llm_provider", out var llmProvider)) config.LlmProvider = llmProvider.ToString() ?? "groq";
if (model.TryGetValue("whisper_provider", out var whisperProvider)) config.WhisperProvider = whisperProvider.ToString() ?? "groq";
if (model.TryGetValue("typing_backend", out var typingBackend)) config.TypingBackend = typingBackend.ToString() ?? "wtype";
if (model.TryGetValue("audio_backend", out var audioBackend)) config.AudioBackend = audioBackend.ToString() ?? "pw-record";
if (model.TryGetValue("llm_model", out var llmModel)) config.LlmModel = llmModel.ToString() ?? "openai/gpt-oss-20b";
if (model.TryGetValue("whisper_model", out var whisperModel)) config.WhisperModel = whisperModel.ToString() ?? "whisper-large-v3-turbo";
if (model.TryGetValue("reasoning_effort", out var reasoningEffort)) config.ReasoningEffort = reasoningEffort.ToString() ?? "none";
if (model.TryGetValue("min_recording_duration", out var minDuration)) config.MinRecordingDuration = Convert.ToInt32(minDuration);
if (model.TryGetValue("whisper_language", out var language)) config.WhisperLanguage = language.ToString() ?? string.Empty;
return config;
} }
catch catch
{ {
return new HushConfig(); // Silent fallback to base config on any profile error
} }
return config;
} }
public IEnumerable<string> ListProfiles()
{
if (!Directory.Exists(_profilesDir))
return [];
return Directory.GetFiles(_profilesDir)
.Select(Path.GetFileName)
.Where(n => n != null)
.Cast<string>()
.Order();
}
public string GetProfilePath(string profileName) => Path.Combine(_profilesDir, profileName);
public void EnsureProfilesDirExists() => Directory.CreateDirectory(_profilesDir);
public void Save(HushConfig config) public void Save(HushConfig config)
{ {
Directory.CreateDirectory(_configDir); Directory.CreateDirectory(_configDir);
@@ -61,8 +69,6 @@ public class ConfigManager
var model = new TomlTable var model = new TomlTable
{ {
["groq_api_key"] = config.GroqApiKey, ["groq_api_key"] = config.GroqApiKey,
["together_api_key"] = config.TogetherApiKey,
["cerebras_api_key"] = config.CerebrasApiKey,
["fireworks_api_key"] = config.FireworksApiKey, ["fireworks_api_key"] = config.FireworksApiKey,
["llm_provider"] = config.LlmProvider, ["llm_provider"] = config.LlmProvider,
@@ -72,13 +78,51 @@ public class ConfigManager
["llm_model"] = config.LlmModel, ["llm_model"] = config.LlmModel,
["whisper_model"] = config.WhisperModel, ["whisper_model"] = config.WhisperModel,
["reasoning_effort"] = config.ReasoningEffort,
["min_recording_duration"] = config.MinRecordingDuration, ["min_recording_duration"] = config.MinRecordingDuration,
["whisper_language"] = config.WhisperLanguage ["whisper_language"] = config.WhisperLanguage,
["system_prompt"] = config.SystemPrompt
}; };
var toml = Toml.FromModel(model); var toml = Toml.FromModel(model);
File.WriteAllText(_configPath, toml); File.WriteAllText(_configPath, toml);
} }
private static HushConfig LoadFromFile(string path)
{
if (!File.Exists(path))
return new HushConfig();
try
{
var toml = File.ReadAllText(path);
var model = Toml.ToModel<TomlTable>(toml);
var config = new HushConfig();
ApplyTomlFields(model, config);
return config;
}
catch
{
return new HushConfig();
}
}
private static void ApplyTomlFields(TomlTable model, HushConfig config)
{
if (model.TryGetValue("groq_api_key", out var groqKey)) config.GroqApiKey = groqKey.ToString() ?? string.Empty;
if (model.TryGetValue("fireworks_api_key", out var fireworksKey)) config.FireworksApiKey = fireworksKey.ToString() ?? string.Empty;
if (model.TryGetValue("llm_provider", out var llmProvider)) config.LlmProvider = llmProvider.ToString() ?? "groq";
if (model.TryGetValue("whisper_provider", out var whisperProvider)) config.WhisperProvider = whisperProvider.ToString() ?? "groq";
if (model.TryGetValue("typing_backend", out var typingBackend)) config.TypingBackend = typingBackend.ToString() ?? "wtype";
if (model.TryGetValue("audio_backend", out var audioBackend)) config.AudioBackend = audioBackend.ToString() ?? "pw-record";
if (model.TryGetValue("llm_model", out var llmModel)) config.LlmModel = llmModel.ToString() ?? "openai/gpt-oss-20b";
if (model.TryGetValue("whisper_model", out var whisperModel)) config.WhisperModel = whisperModel.ToString() ?? "whisper-large-v3-turbo";
if (model.TryGetValue("min_recording_duration", out var minDuration)) config.MinRecordingDuration = Convert.ToInt32(minDuration);
if (model.TryGetValue("whisper_language", out var language)) config.WhisperLanguage = language.ToString() ?? string.Empty;
if (model.TryGetValue("system_prompt", out var systemPrompt)) config.SystemPrompt = systemPrompt.ToString() ?? string.Empty;
}
} }
+3 -3
View File
@@ -5,8 +5,6 @@ namespace Hush.Config;
public class HushConfig public class HushConfig
{ {
public string GroqApiKey { get; set; } = string.Empty; public string GroqApiKey { get; set; } = string.Empty;
public string TogetherApiKey { get; set; } = string.Empty;
public string CerebrasApiKey { get; set; } = string.Empty;
public string FireworksApiKey { get; set; } = string.Empty; public string FireworksApiKey { get; set; } = string.Empty;
public string LlmProvider { get; set; } = "groq"; public string LlmProvider { get; set; } = "groq";
@@ -16,10 +14,12 @@ public class HushConfig
public string LlmModel { get; set; } = "openai/gpt-oss-20b"; public string LlmModel { get; set; } = "openai/gpt-oss-20b";
public string WhisperModel { get; set; } = "whisper-large-v3-turbo"; public string WhisperModel { get; set; } = "whisper-large-v3-turbo";
public string ReasoningEffort { get; set; } = "none";
public int MinRecordingDuration { get; set; } = 500; public int MinRecordingDuration { get; set; } = 500;
public string WhisperLanguage { get; set; } = string.Empty; public string WhisperLanguage { get; set; } = string.Empty;
// Empty = use the built-in default transcription cleanup prompt
public string SystemPrompt { get; set; } = string.Empty;
} }
[JsonSerializable(typeof(HushConfig))] [JsonSerializable(typeof(HushConfig))]
+9
View File
@@ -1,4 +1,5 @@
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using Hush.Config;
namespace Hush.Daemon; namespace Hush.Daemon;
@@ -10,6 +11,7 @@ public static class DaemonProtocol
public const byte TOGGLE = 4; // Start if idle, stop if recording public const byte TOGGLE = 4; // Start if idle, stop if recording
public const byte STATUS = 5; // Return state as JSON public const byte STATUS = 5; // Return state as JSON
public const byte LATENCY_TEST = 6; // Run latency test, return timing JSON public const byte LATENCY_TEST = 6; // Run latency test, return timing JSON
public const byte GENERATE_PROFILE = 7; // Generate a system prompt from a description
} }
public record LatencyResult(int SttMs, int LlmMs, int TotalMs); public record LatencyResult(int SttMs, int LlmMs, int TotalMs);
@@ -18,8 +20,15 @@ public record StatusResponse(string State, long? DurationMs = null);
public record ErrorResponse(string Error); public record ErrorResponse(string Error);
public record GenerateProfileRequest(string Description);
public record GenerateProfileResponse(string SystemPrompt);
[JsonSerializable(typeof(LatencyResult))] [JsonSerializable(typeof(LatencyResult))]
[JsonSerializable(typeof(StatusResponse))] [JsonSerializable(typeof(StatusResponse))]
[JsonSerializable(typeof(ErrorResponse))] [JsonSerializable(typeof(ErrorResponse))]
[JsonSerializable(typeof(GenerateProfileRequest))]
[JsonSerializable(typeof(GenerateProfileResponse))]
[JsonSerializable(typeof(HushConfig))]
[JsonSerializable(typeof(string))] [JsonSerializable(typeof(string))]
public partial class DaemonJsonContext : JsonSerializerContext; public partial class DaemonJsonContext : JsonSerializerContext;
+114 -88
View File
@@ -1,4 +1,6 @@
using System.Net.Sockets; using System.Net.Sockets;
using System.Text;
using System.Text.Json;
using Hush.Config; using Hush.Config;
namespace Hush.Daemon; namespace Hush.Daemon;
@@ -34,10 +36,7 @@ public class DaemonService
if (File.Exists(socketPath)) if (File.Exists(socketPath))
{ {
try { File.Delete(socketPath); } try { File.Delete(socketPath); }
catch catch { /* ignored */ }
{
// ignored
}
} }
var configManager = new ConfigManager(); var configManager = new ConfigManager();
@@ -65,119 +64,103 @@ public class DaemonService
finally finally
{ {
if (File.Exists(socketPath)) if (File.Exists(socketPath))
{
File.Delete(socketPath); File.Delete(socketPath);
} }
} }
}
private static async Task HandleClientAsync(Socket client, Orchestrator orchestrator) private static async Task HandleClientAsync(Socket client, Orchestrator orchestrator)
{ {
try try
{ {
var buffer = new byte[1]; // Read command byte
var bytesRead = await client.ReceiveAsync(buffer, SocketFlags.None); var cmdBuffer = new byte[1];
var bytesRead = await client.ReceiveAsync(cmdBuffer, SocketFlags.None);
if (bytesRead == 0) { client.Close(); return; }
if (bytesRead == 0) var cmd = cmdBuffer[0];
{
client.Close();
return;
}
var cmd = buffer[0];
switch (cmd) switch (cmd)
{ {
case DaemonProtocol.START: case DaemonProtocol.START:
await HandleStartAsync(orchestrator);
break;
case DaemonProtocol.STOP: case DaemonProtocol.STOP:
await HandleStopAsync(orchestrator);
break;
case DaemonProtocol.ABORT: case DaemonProtocol.ABORT:
await HandleAbortAsync(orchestrator);
break;
case DaemonProtocol.TOGGLE: case DaemonProtocol.TOGGLE:
await HandleToggleAsync(orchestrator);
break;
case DaemonProtocol.STATUS:
await HandleStatusAsync(client, orchestrator);
break;
case DaemonProtocol.LATENCY_TEST:
await HandleLatencyTestAsync(client, orchestrator);
break;
}
}
catch (Exception ex)
{ {
Console.WriteLine($"HandleClient error: {ex.Message}"); // These commands carry an optional HushConfig payload: [4-byte LE length][JSON]
} var overrideConfig = await ReadConfigPayloadAsync(client);
finally switch (cmd)
{ {
client.Close(); case DaemonProtocol.START: await HandleStartAsync(orchestrator); break;
case DaemonProtocol.STOP: await HandleStopAsync(orchestrator, overrideConfig); break;
case DaemonProtocol.ABORT: await HandleAbortAsync(orchestrator); break;
case DaemonProtocol.TOGGLE: await HandleToggleAsync(orchestrator, overrideConfig); break;
} }
break;
}
case DaemonProtocol.STATUS: await HandleStatusAsync(client, orchestrator); break;
case DaemonProtocol.LATENCY_TEST: await HandleLatencyTestAsync(client, orchestrator); break;
case DaemonProtocol.GENERATE_PROFILE: await HandleGenerateProfileAsync(client, orchestrator); break;
}
}
catch (Exception ex) { Console.WriteLine($"HandleClient error: {ex.Message}"); }
finally { client.Close(); }
}
/// <summary>
/// Reads the optional HushConfig payload that follows action commands.
/// Format: [4-byte LE int32 length][N bytes JSON]. Returns null if length == 0.
/// </summary>
private static async Task<HushConfig?> ReadConfigPayloadAsync(Socket client)
{
var lenBuffer = new byte[4];
var totalRead = 0;
while (totalRead < 4)
{
var n = await client.ReceiveAsync(lenBuffer.AsMemory(totalRead), SocketFlags.None);
if (n == 0) return null;
totalRead += n;
}
var length = BitConverter.ToInt32(lenBuffer, 0);
if (length == 0) return null;
var jsonBuffer = new byte[length];
totalRead = 0;
while (totalRead < length)
{
var n = await client.ReceiveAsync(jsonBuffer.AsMemory(totalRead), SocketFlags.None);
if (n == 0) break;
totalRead += n;
}
return JsonSerializer.Deserialize(jsonBuffer, DaemonJsonContext.Default.HushConfig);
} }
private static async Task HandleStartAsync(Orchestrator orchestrator) private static async Task HandleStartAsync(Orchestrator orchestrator)
{ {
if (orchestrator.IsRecording) if (orchestrator.IsRecording) { Console.WriteLine("Already recording"); return; }
{ try { await orchestrator.StartRecordingAsync(); Console.WriteLine("Recording started"); }
Console.WriteLine("Already recording"); catch (Exception ex) { Console.WriteLine($"Failed to start recording: {ex.Message}"); }
return;
} }
try private static async Task HandleStopAsync(Orchestrator orchestrator, HushConfig? overrideConfig)
{ {
await orchestrator.StartRecordingAsync(); if (!orchestrator.IsRecording) { Console.WriteLine("Not recording"); return; }
Console.WriteLine("Recording started"); try { await orchestrator.StopAndProcessAsync(overrideConfig); Console.WriteLine("Recording stopped and processed"); }
} catch (Exception ex) { Console.WriteLine($"Failed to stop recording: {ex.Message}"); }
catch (Exception ex)
{
Console.WriteLine($"Failed to start recording: {ex.Message}");
}
}
private static async Task HandleStopAsync(Orchestrator orchestrator)
{
if (!orchestrator.IsRecording)
{
Console.WriteLine("Not recording");
return;
}
try
{
await orchestrator.StopAndProcessAsync();
Console.WriteLine("Recording stopped and processed");
}
catch (Exception ex)
{
Console.WriteLine($"Failed to stop recording: {ex.Message}");
}
} }
private static async Task HandleAbortAsync(Orchestrator orchestrator) private static async Task HandleAbortAsync(Orchestrator orchestrator)
{ {
if (!orchestrator.IsRecording) if (!orchestrator.IsRecording) { Console.WriteLine("Not recording"); return; }
{
Console.WriteLine("Not recording");
return;
}
await orchestrator.AbortAsync(); await orchestrator.AbortAsync();
Console.WriteLine("Recording aborted"); Console.WriteLine("Recording aborted");
} }
private static async Task HandleToggleAsync(Orchestrator orchestrator) private static async Task HandleToggleAsync(Orchestrator orchestrator, HushConfig? overrideConfig)
{ {
if (orchestrator.IsRecording) if (orchestrator.IsRecording) await HandleStopAsync(orchestrator, overrideConfig);
{ else await HandleStartAsync(orchestrator);
await HandleStopAsync(orchestrator);
}
else
{
await HandleStartAsync(orchestrator);
}
} }
private static async Task HandleStatusAsync(Socket client, Orchestrator orchestrator) private static async Task HandleStatusAsync(Socket client, Orchestrator orchestrator)
@@ -189,8 +172,8 @@ public class DaemonService
? new StatusResponse("recording", (long?)durationMs) ? new StatusResponse("recording", (long?)durationMs)
: new StatusResponse("idle"); : new StatusResponse("idle");
var json = System.Text.Json.JsonSerializer.Serialize(responseObj, DaemonJsonContext.Default.StatusResponse); var json = JsonSerializer.Serialize(responseObj, DaemonJsonContext.Default.StatusResponse);
var response = System.Text.Encoding.UTF8.GetBytes(json); var response = Encoding.UTF8.GetBytes(json);
await client.SendAsync(response, SocketFlags.None); await client.SendAsync(response, SocketFlags.None);
} }
@@ -199,15 +182,58 @@ public class DaemonService
try try
{ {
var result = await orchestrator.RunLatencyTestAsync(); var result = await orchestrator.RunLatencyTestAsync();
var json = System.Text.Json.JsonSerializer.Serialize(result, DaemonJsonContext.Default.LatencyResult); var json = JsonSerializer.Serialize(result, DaemonJsonContext.Default.LatencyResult);
var response = System.Text.Encoding.UTF8.GetBytes(json); var response = Encoding.UTF8.GetBytes(json);
await client.SendAsync(response, SocketFlags.None); await client.SendAsync(response, SocketFlags.None);
} }
catch (Exception ex) catch (Exception ex)
{ {
var error = new ErrorResponse(ex.Message); var error = new ErrorResponse(ex.Message);
var json = System.Text.Json.JsonSerializer.Serialize(error, DaemonJsonContext.Default.ErrorResponse); var json = JsonSerializer.Serialize(error, DaemonJsonContext.Default.ErrorResponse);
var response = System.Text.Encoding.UTF8.GetBytes(json); var response = Encoding.UTF8.GetBytes(json);
await client.SendAsync(response, SocketFlags.None);
}
}
private static async Task HandleGenerateProfileAsync(Socket client, Orchestrator orchestrator)
{
try
{
// Read GenerateProfileRequest payload: [4-byte LE length][JSON]
var lenBuffer = new byte[4];
var totalRead = 0;
while (totalRead < 4)
{
var n = await client.ReceiveAsync(lenBuffer.AsMemory(totalRead), SocketFlags.None);
if (n == 0) throw new InvalidOperationException("Connection closed before length prefix");
totalRead += n;
}
var length = BitConverter.ToInt32(lenBuffer, 0);
var jsonBuffer = new byte[length];
totalRead = 0;
while (totalRead < length)
{
var n = await client.ReceiveAsync(jsonBuffer.AsMemory(totalRead), SocketFlags.None);
if (n == 0) break;
totalRead += n;
}
var request = JsonSerializer.Deserialize(jsonBuffer, DaemonJsonContext.Default.GenerateProfileRequest)
?? throw new InvalidOperationException("Failed to deserialize GenerateProfileRequest");
var systemPrompt = await orchestrator.GenerateProfilePromptAsync(request.Description);
var responseObj = new GenerateProfileResponse(systemPrompt);
var json = JsonSerializer.Serialize(responseObj, DaemonJsonContext.Default.GenerateProfileResponse);
var response = Encoding.UTF8.GetBytes(json);
await client.SendAsync(response, SocketFlags.None);
}
catch (Exception ex)
{
var error = new ErrorResponse(ex.Message);
var json = JsonSerializer.Serialize(error, DaemonJsonContext.Default.ErrorResponse);
var response = Encoding.UTF8.GetBytes(json);
await client.SendAsync(response, SocketFlags.None); await client.SendAsync(response, SocketFlags.None);
} }
} }
+42 -11
View File
@@ -18,6 +18,8 @@ public class Orchestrator
private bool _isRecording; private bool _isRecording;
private readonly Lock _lock = new(); private readonly Lock _lock = new();
public Orchestrator(ConfigManager configManager) public Orchestrator(ConfigManager configManager)
{ {
_configManager = configManager; _configManager = configManager;
@@ -61,7 +63,7 @@ public class Orchestrator
return _recorder.StartRecording(_recordingPath); return _recorder.StartRecording(_recordingPath);
} }
public async Task StopAndProcessAsync() public async Task StopAndProcessAsync(HushConfig? overrideConfig = null)
{ {
string? recordingPath; string? recordingPath;
DateTime? recordingStartTime; DateTime? recordingStartTime;
@@ -86,7 +88,7 @@ public class Orchestrator
try try
{ {
var config = _configManager.Load(); var config = overrideConfig ?? _configManager.Load();
var recordingDuration = recordingStartTime.HasValue var recordingDuration = recordingStartTime.HasValue
? DateTime.UtcNow - recordingStartTime.Value ? DateTime.UtcNow - recordingStartTime.Value
@@ -140,14 +142,14 @@ public class Orchestrator
var provider = GetAudioToTextProvider(config); var provider = GetAudioToTextProvider(config);
await using var stream = File.OpenRead(path); await using var stream = File.OpenRead(path);
return await provider.TranscribeAsync(stream, config.WhisperModel); return await provider.TranscribeAsync(
stream,
config.WhisperModel,
language: string.IsNullOrEmpty(config.WhisperLanguage) ? null : config.WhisperLanguage);
} }
private async Task<string> ProcessWithLlmAsync(string text, HushConfig config) private const string DefaultSystemPrompt =
{ """
var provider = GetTextProvider(config);
var prompt = $"""
You are a transcription post-processor. Your task is to clean up raw speech-to-text output and return polished, ready-to-type text. You are a transcription post-processor. Your task is to clean up raw speech-to-text output and return polished, ready-to-type text.
Rules: Rules:
@@ -159,11 +161,40 @@ public class Orchestrator
- Do not add, remove, or reinterpret content beyond what was said - Do not add, remove, or reinterpret content beyond what was said
- Do not include any explanation, preamble, or metadata output only the corrected text - Do not include any explanation, preamble, or metadata output only the corrected text
- If the input is empty or unintelligible, return an empty string - If the input is empty or unintelligible, return an empty string
Raw transcription: {text}
"""; """;
return await provider.CompleteTextAsync(prompt, config.LlmModel); private async Task<string> ProcessWithLlmAsync(string text, HushConfig config)
{
var provider = GetTextProvider(config);
var systemPrompt = string.IsNullOrWhiteSpace(config.SystemPrompt)
? DefaultSystemPrompt
: config.SystemPrompt;
return await provider.CompleteTextAsync(systemPrompt, text, config.LlmModel);
}
public async Task<string> GenerateProfilePromptAsync(string description)
{
var config = _configManager.Load();
var provider = GetTextProvider(config);
const string systemPrompt =
"""
You are a configuration assistant for Hush, a Linux speech-to-text post-processor.
Hush records the user's voice, transcribes it with Whisper, then passes the transcription
to an LLM using a system prompt you will write.
Given the user's description of what they want the profile to do, write a precise, concise
system prompt that instructs the LLM how to transform the raw transcription.
Rules:
- Output only the system prompt text, nothing else
- Do not include meta-commentary, labels, or markdown formatting
- The prompt must be self-contained and unambiguous
- Always end with an instruction to output only the final result with no explanation
""";
return await provider.CompleteTextAsync(systemPrompt, description, config.LlmModel);
} }
private async Task TypeAsync(string text, HushConfig config) private async Task TypeAsync(string text, HushConfig config)
@@ -15,5 +15,6 @@ public interface IAudioToTextProvider
Task<string> TranscribeAsync( Task<string> TranscribeAsync(
Stream audioStream, Stream audioStream,
string modelName, string modelName,
string? language = null,
CancellationToken cancellationToken = default); CancellationToken cancellationToken = default);
} }
@@ -1,31 +1,21 @@
namespace Hush.Providers.Interfaces; namespace Hush.Providers.Interfaces;
/// <summary> /// <summary>
/// Interface for text generation with both synchronous and streaming capabilities. /// Interface for text generation.
/// </summary> /// </summary>
public interface ITextStreamingProvider public interface ITextStreamingProvider
{ {
/// <summary> /// <summary>
/// Generates text completion for a given prompt. /// Generates a text completion using a system prompt and a user message.
/// </summary> /// </summary>
/// <param name="prompt">The input prompt</param> /// <param name="systemPrompt">The system prompt that instructs the model how to behave</param>
/// <param name="userMessage">The user message to process</param>
/// <param name="modelName">The model name to use (e.g., llama-3.3-70b-versatile)</param> /// <param name="modelName">The model name to use (e.g., llama-3.3-70b-versatile)</param>
/// <param name="cancellationToken">Cancellation token</param> /// <param name="cancellationToken">Cancellation token</param>
/// <returns>The generated text</returns> /// <returns>The generated text</returns>
Task<string> CompleteTextAsync( Task<string> CompleteTextAsync(
string prompt, string systemPrompt,
string modelName, string userMessage,
CancellationToken cancellationToken = default);
/// <summary>
/// Streams text generation for a given prompt.
/// </summary>
/// <param name="prompt">The input prompt</param>
/// <param name="modelName">The model name to use (e.g., llama-3.3-70b-versatile)</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <returns>Async enumerable of text chunks</returns>
IAsyncEnumerable<string> StreamTextAsync(
string prompt,
string modelName, string modelName,
CancellationToken cancellationToken = default); CancellationToken cancellationToken = default);
} }
@@ -34,6 +34,7 @@ public class FireworksProvider : IAudioToTextProvider, ITextStreamingProvider
public async Task<string> TranscribeAsync( public async Task<string> TranscribeAsync(
Stream audioStream, Stream audioStream,
string modelName, string modelName,
string? language = null,
CancellationToken cancellationToken = default) CancellationToken cancellationToken = default)
{ {
ArgumentNullException.ThrowIfNull(audioStream); ArgumentNullException.ThrowIfNull(audioStream);
@@ -45,7 +46,7 @@ public class FireworksProvider : IAudioToTextProvider, ITextStreamingProvider
? TRANSCRIPTION_ENDPOINT_TURBO ? TRANSCRIPTION_ENDPOINT_TURBO
: TRANSCRIPTION_ENDPOINT_PROD; : TRANSCRIPTION_ENDPOINT_PROD;
var request = new TranscriptionRequest { Model = modelName }; var request = new TranscriptionRequest { Model = modelName, Language = language };
using var content = new MultipartFormDataContent(); using var content = new MultipartFormDataContent();
content.Add(new StreamContent(audioStream), "file", "audio.wav"); content.Add(new StreamContent(audioStream), "file", "audio.wav");
@@ -84,12 +85,13 @@ public class FireworksProvider : IAudioToTextProvider, ITextStreamingProvider
/// <inheritdoc /> /// <inheritdoc />
public async Task<string> CompleteTextAsync( public async Task<string> CompleteTextAsync(
string prompt, string systemPrompt,
string userMessage,
string modelName, string modelName,
CancellationToken cancellationToken = default) CancellationToken cancellationToken = default)
{ {
if (string.IsNullOrWhiteSpace(prompt)) if (string.IsNullOrWhiteSpace(systemPrompt))
throw new ArgumentException("Prompt is required", nameof(prompt)); throw new ArgumentException("System prompt is required", nameof(systemPrompt));
if (string.IsNullOrWhiteSpace(modelName)) if (string.IsNullOrWhiteSpace(modelName))
throw new ArgumentException("Model name is required", nameof(modelName)); throw new ArgumentException("Model name is required", nameof(modelName));
@@ -97,7 +99,11 @@ public class FireworksProvider : IAudioToTextProvider, ITextStreamingProvider
var request = new ChatCompletionRequest var request = new ChatCompletionRequest
{ {
Model = modelName, Model = modelName,
Messages = new List<Message> { new() { Role = "user", Content = prompt } } Messages = new List<Message>
{
new() { Role = "system", Content = systemPrompt },
new() { Role = "user", Content = userMessage }
}
}; };
var jsonContent = new StringContent( var jsonContent = new StringContent(
@@ -126,86 +132,5 @@ public class FireworksProvider : IAudioToTextProvider, ITextStreamingProvider
return result.Choices[0].Message.Content; return result.Choices[0].Message.Content;
} }
/// <inheritdoc />
public async IAsyncEnumerable<string> StreamTextAsync(
string prompt,
string modelName,
[System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (string.IsNullOrWhiteSpace(prompt))
throw new ArgumentException("Prompt is required", nameof(prompt));
if (string.IsNullOrWhiteSpace(modelName))
throw new ArgumentException("Model name is required", nameof(modelName));
var request = new ChatCompletionRequest
{
Model = modelName,
Stream = true,
Messages = new List<Message> { new() { Role = "user", Content = prompt } }
};
var jsonContent = new StringContent(
JsonSerializer.Serialize(request, JsonSourceGeneration.Default.ChatCompletionRequest),
Encoding.UTF8,
"application/json");
var httpRequest = new HttpRequestMessage(HttpMethod.Post, CHAT_COMPLETION_ENDPOINT)
{
Content = jsonContent
};
httpRequest.Headers.TryAddWithoutValidation("Authorization", _apiKey);
using var response = await _httpClient.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
response.EnsureSuccessStatusCode();
using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
using var reader = new StreamReader(stream);
string? line;
while ((line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) != null)
{
if (string.IsNullOrWhiteSpace(line) || !line.StartsWith("data: "))
continue;
var data = line.Substring(6).Trim(); // Remove "data: " prefix
if (data == "[DONE]")
break;
var text = ParseTextFromStreamData(data);
if (!string.IsNullOrEmpty(text))
yield return text;
}
}
private static string? ParseTextFromStreamData(string data)
{
try
{
using var jsonDoc = JsonDocument.Parse(data);
var choices = jsonDoc.RootElement.GetProperty("choices");
var choice = choices[0];
if (choice.TryGetProperty("delta", out var delta))
{
if (delta.TryGetProperty("content", out var content))
{
return content.GetString();
}
}
else if (choice.TryGetProperty("text", out var text))
{
return text.GetString();
}
}
catch (JsonException)
{
// Skip malformed JSON chunks
}
return null;
}
} }
+11 -88
View File
@@ -33,6 +33,7 @@ public class GroqProvider : IAudioToTextProvider, ITextStreamingProvider
public async Task<string> TranscribeAsync( public async Task<string> TranscribeAsync(
Stream audioStream, Stream audioStream,
string modelName, string modelName,
string? language = null,
CancellationToken cancellationToken = default) CancellationToken cancellationToken = default)
{ {
ArgumentNullException.ThrowIfNull(audioStream); ArgumentNullException.ThrowIfNull(audioStream);
@@ -40,7 +41,7 @@ public class GroqProvider : IAudioToTextProvider, ITextStreamingProvider
if (string.IsNullOrWhiteSpace(modelName)) if (string.IsNullOrWhiteSpace(modelName))
throw new ArgumentException("Model name is required", nameof(modelName)); throw new ArgumentException("Model name is required", nameof(modelName));
var request = new TranscriptionRequest { Model = modelName }; var request = new TranscriptionRequest { Model = modelName, Language = language };
using var content = new MultipartFormDataContent(); using var content = new MultipartFormDataContent();
content.Add(new StreamContent(audioStream), "file", "audio.wav"); content.Add(new StreamContent(audioStream), "file", "audio.wav");
@@ -79,20 +80,24 @@ public class GroqProvider : IAudioToTextProvider, ITextStreamingProvider
/// <inheritdoc /> /// <inheritdoc />
public async Task<string> CompleteTextAsync( public async Task<string> CompleteTextAsync(
string prompt, string systemPrompt,
string userMessage,
string modelName, string modelName,
CancellationToken cancellationToken = default) CancellationToken cancellationToken = default)
{ {
if (string.IsNullOrWhiteSpace(prompt)) if (string.IsNullOrWhiteSpace(systemPrompt))
throw new ArgumentException("Prompt is required", nameof(prompt)); throw new ArgumentException("System prompt is required", nameof(systemPrompt));
if (string.IsNullOrWhiteSpace(modelName)) if (string.IsNullOrWhiteSpace(modelName))
throw new ArgumentException("Model name is required", nameof(modelName)); throw new ArgumentException("Model name is required", nameof(modelName));
var request = new ChatCompletionRequest var request = new ChatCompletionRequest
{ {
Model = modelName, Model = modelName,
Messages = new List<Models.Request.Message> { new() { Role = "user", Content = prompt } } Messages = new List<Models.Request.Message>
{
new() { Role = "system", Content = systemPrompt },
new() { Role = "user", Content = userMessage }
}
}; };
var jsonContent = new StringContent( var jsonContent = new StringContent(
@@ -121,86 +126,4 @@ public class GroqProvider : IAudioToTextProvider, ITextStreamingProvider
return result.Choices[0].Message.Content; return result.Choices[0].Message.Content;
} }
/// <inheritdoc />
public async IAsyncEnumerable<string> StreamTextAsync(
string prompt,
string modelName,
[System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (string.IsNullOrWhiteSpace(prompt))
throw new ArgumentException("Prompt is required", nameof(prompt));
if (string.IsNullOrWhiteSpace(modelName))
throw new ArgumentException("Model name is required", nameof(modelName));
var request = new ChatCompletionRequest
{
Model = modelName,
Stream = true,
Messages = new List<Models.Request.Message> { new() { Role = "user", Content = prompt } }
};
var jsonContent = new StringContent(
JsonSerializer.Serialize(request, JsonSourceGeneration.Default.ChatCompletionRequest),
Encoding.UTF8,
"application/json");
var httpRequest = new HttpRequestMessage(HttpMethod.Post, CHAT_COMPLETION_ENDPOINT)
{
Content = jsonContent
};
httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _apiKey);
using var response = await _httpClient.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
response.EnsureSuccessStatusCode();
using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
using var reader = new StreamReader(stream);
string? line;
while ((line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) != null)
{
if (string.IsNullOrWhiteSpace(line) || !line.StartsWith("data: "))
continue;
var data = line.Substring(6).Trim(); // Remove "data: " prefix
if (data == "[DONE]")
break;
var text = ParseTextFromStreamData(data);
if (!string.IsNullOrEmpty(text))
yield return text;
}
}
private static string? ParseTextFromStreamData(string data)
{
try
{
using var jsonDoc = JsonDocument.Parse(data);
var choices = jsonDoc.RootElement.GetProperty("choices");
var choice = choices[0];
if (choice.TryGetProperty("delta", out var delta))
{
if (delta.TryGetProperty("content", out var content))
{
return content.GetString();
}
}
else if (choice.TryGetProperty("text", out var text))
{
return text.GetString();
}
}
catch (JsonException)
{
// Skip malformed JSON chunks
}
return null;
}
} }