add Fireworks AI provider support
This commit is contained in:
@@ -32,48 +32,62 @@ public static class SetupCommand
|
|||||||
AnsiConsole.MarkupLine("[bold blue]Welcome to Hush Setup![/]\n");
|
AnsiConsole.MarkupLine("[bold blue]Welcome to Hush Setup![/]\n");
|
||||||
AnsiConsole.MarkupLine("This wizard will help you configure Hush.\n");
|
AnsiConsole.MarkupLine("This wizard will help you configure Hush.\n");
|
||||||
|
|
||||||
var config = new HushConfig();
|
var configManager = new ConfigManager();
|
||||||
|
var config = configManager.Load();
|
||||||
|
|
||||||
AnsiConsole.MarkupLine("[bold]Step 1: Whisper Provider[/]");
|
AnsiConsole.MarkupLine("[bold]Step 1: Whisper Provider[/]");
|
||||||
config.WhisperProvider = AnsiConsole.Prompt(
|
config.WhisperProvider = PromptSelection("Select Whisper provider:", ["groq", "fireworks"], config.WhisperProvider);
|
||||||
new SelectionPrompt<string>()
|
|
||||||
.Title("Select Whisper provider:")
|
|
||||||
.AddChoices("groq"));
|
|
||||||
|
|
||||||
AnsiConsole.MarkupLine("[bold]Step 2: Groq API Key[/]");
|
AnsiConsole.MarkupLine("[bold]Step 2: LLM Provider[/]");
|
||||||
config.GroqApiKey = ReadMaskedInput("Enter your Groq API key:");
|
config.LlmProvider = PromptSelection("Select LLM provider:", ["groq", "fireworks"], config.LlmProvider);
|
||||||
|
|
||||||
AnsiConsole.MarkupLine("[bold]Step 3: LLM Model[/]");
|
if (config.WhisperProvider == "groq" || config.LlmProvider == "groq")
|
||||||
config.LlmProvider = "groq";
|
{
|
||||||
config.LlmModel = AnsiConsole.Prompt(
|
AnsiConsole.MarkupLine("[bold]Step 3: Groq API Key[/]");
|
||||||
new SelectionPrompt<string>()
|
config.GroqApiKey = ReadMaskedInput("Enter your Groq API key:", config.GroqApiKey);
|
||||||
.Title("Select LLM model:")
|
}
|
||||||
.AddChoices("openai/gpt-oss-20b", "llama-3.1-8b-instant", "openai/gpt-oss-120b"));
|
|
||||||
|
|
||||||
AnsiConsole.MarkupLine("[bold]Step 4: Whisper Model[/]");
|
if (config.WhisperProvider == "fireworks" || config.LlmProvider == "fireworks")
|
||||||
config.WhisperModel = AnsiConsole.Prompt(
|
{
|
||||||
new SelectionPrompt<string>()
|
AnsiConsole.MarkupLine("[bold]Step 3: Fireworks API Key[/]");
|
||||||
.Title("Select Whisper model:")
|
config.FireworksApiKey = ReadMaskedInput("Enter your Fireworks API key:", config.FireworksApiKey);
|
||||||
.AddChoices("whisper-large-v3", "whisper-large-v3-turbo"));
|
}
|
||||||
|
|
||||||
AnsiConsole.MarkupLine("[bold]Step 5: Audio Backend[/]");
|
AnsiConsole.MarkupLine("[bold]Step 4: LLM Model[/]");
|
||||||
config.AudioBackend = AnsiConsole.Prompt(
|
config.LlmModel = config.LlmProvider switch
|
||||||
new SelectionPrompt<string>()
|
{
|
||||||
.Title("Select audio backend:")
|
"fireworks" => PromptSelection("Select LLM model:", [
|
||||||
.AddChoices("pipewire", "ffmpeg"));
|
"accounts/fireworks/models/kimi-k2-instruct-0905",
|
||||||
|
"accounts/fireworks/models/deepseek-v3p1",
|
||||||
|
"accounts/fireworks/models/llama-v3p1-70b-instruct"
|
||||||
|
], config.LlmModel),
|
||||||
|
_ => PromptSelection("Select LLM model:", [
|
||||||
|
"openai/gpt-oss-20b", "llama-3.1-8b-instant", "openai/gpt-oss-120b"
|
||||||
|
], config.LlmModel)
|
||||||
|
};
|
||||||
|
|
||||||
AnsiConsole.MarkupLine("[bold]Step 6: Typing Backend[/]");
|
AnsiConsole.MarkupLine("[bold]Step 5: Whisper Model[/]");
|
||||||
config.TypingBackend = AnsiConsole.Prompt(
|
config.WhisperModel = config.WhisperProvider switch
|
||||||
new SelectionPrompt<string>()
|
{
|
||||||
.Title("Select typing backend:")
|
"fireworks" => PromptSelection("Select Whisper model:", [
|
||||||
.AddChoices("wtype", "xdotool"));
|
"whisper-v3", "whisper-v3-turbo"
|
||||||
|
], config.WhisperModel),
|
||||||
|
_ => PromptSelection("Select Whisper model:", [
|
||||||
|
"whisper-large-v3", "whisper-large-v3-turbo"
|
||||||
|
], config.WhisperModel)
|
||||||
|
};
|
||||||
|
|
||||||
AnsiConsole.MarkupLine("[bold]Step 7: Minimum Recording Duration[/]");
|
AnsiConsole.MarkupLine("[bold]Step 6: Audio Backend[/]");
|
||||||
var minDuration = AnsiConsole.Prompt(
|
config.AudioBackend = PromptSelection("Select audio backend:", ["pipewire", "ffmpeg"], config.AudioBackend);
|
||||||
new TextPrompt<int>("Enter minimum duration (ms, default 500):")
|
|
||||||
.DefaultValue(500)
|
AnsiConsole.MarkupLine("[bold]Step 7: Typing Backend[/]");
|
||||||
|
config.TypingBackend = PromptSelection("Select typing backend:", ["wtype", "xdotool"], config.TypingBackend);
|
||||||
|
|
||||||
|
AnsiConsole.MarkupLine("[bold]Step 8: Minimum Recording Duration[/]");
|
||||||
|
config.MinRecordingDuration = AnsiConsole.Prompt(
|
||||||
|
new TextPrompt<int>("Enter minimum duration (ms):")
|
||||||
|
.DefaultValue(config.MinRecordingDuration)
|
||||||
.Validate(x => x > 0, "Must be greater than 0"));
|
.Validate(x => x > 0, "Must be greater than 0"));
|
||||||
config.MinRecordingDuration = minDuration;
|
|
||||||
|
|
||||||
AnsiConsole.WriteLine();
|
AnsiConsole.WriteLine();
|
||||||
AnsiConsole.MarkupLine("[bold]Configuration Summary:[/]");
|
AnsiConsole.MarkupLine("[bold]Configuration Summary:[/]");
|
||||||
@@ -82,7 +96,11 @@ public static class SetupCommand
|
|||||||
table.AddColumn("Setting");
|
table.AddColumn("Setting");
|
||||||
table.AddColumn("Value");
|
table.AddColumn("Value");
|
||||||
table.AddRow("Whisper Provider", config.WhisperProvider);
|
table.AddRow("Whisper Provider", config.WhisperProvider);
|
||||||
table.AddRow("Groq API Key", MaskApiKey(config.GroqApiKey));
|
table.AddRow("LLM Provider", config.LlmProvider);
|
||||||
|
if (!string.IsNullOrEmpty(config.GroqApiKey))
|
||||||
|
table.AddRow("Groq API Key", MaskApiKey(config.GroqApiKey));
|
||||||
|
if (!string.IsNullOrEmpty(config.FireworksApiKey))
|
||||||
|
table.AddRow("Fireworks API Key", MaskApiKey(config.FireworksApiKey));
|
||||||
table.AddRow("LLM Model", config.LlmModel);
|
table.AddRow("LLM Model", config.LlmModel);
|
||||||
table.AddRow("Whisper Model", config.WhisperModel);
|
table.AddRow("Whisper Model", config.WhisperModel);
|
||||||
table.AddRow("Audio Backend", config.AudioBackend);
|
table.AddRow("Audio Backend", config.AudioBackend);
|
||||||
@@ -94,9 +112,20 @@ public static class SetupCommand
|
|||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static string ReadMaskedInput(string prompt)
|
private static string PromptSelection(string title, string[] choices, string current)
|
||||||
{
|
{
|
||||||
AnsiConsole.Write(prompt + " ");
|
var prompt = new SelectionPrompt<string>().Title(title);
|
||||||
|
// Put current value first so it starts highlighted; add it if it's a custom value not in the list
|
||||||
|
prompt.AddChoice(current);
|
||||||
|
foreach (var choice in choices.Where(c => c != current))
|
||||||
|
prompt.AddChoice(choice);
|
||||||
|
return AnsiConsole.Prompt(prompt);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static string ReadMaskedInput(string prompt, string current = "")
|
||||||
|
{
|
||||||
|
var hint = string.IsNullOrEmpty(current) ? "" : $" [{MaskApiKey(current)}]";
|
||||||
|
AnsiConsole.Write(prompt + hint + " ");
|
||||||
var input = "";
|
var input = "";
|
||||||
while (true)
|
while (true)
|
||||||
{
|
{
|
||||||
@@ -117,7 +146,8 @@ public static class SetupCommand
|
|||||||
AnsiConsole.Write("*");
|
AnsiConsole.Write("*");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return input;
|
// Keep existing key if user just hit Enter without typing anything
|
||||||
|
return string.IsNullOrEmpty(input) ? current : input;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static string MaskApiKey(string key)
|
private static string MaskApiKey(string key)
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ public static class ShowCommand
|
|||||||
{
|
{
|
||||||
AnsiConsole.WriteLine($"whisper_provider={config.WhisperProvider}");
|
AnsiConsole.WriteLine($"whisper_provider={config.WhisperProvider}");
|
||||||
AnsiConsole.WriteLine($"groq_api_key={MaskApiKey(config.GroqApiKey)}");
|
AnsiConsole.WriteLine($"groq_api_key={MaskApiKey(config.GroqApiKey)}");
|
||||||
|
AnsiConsole.WriteLine($"fireworks_api_key={MaskApiKey(config.FireworksApiKey)}");
|
||||||
AnsiConsole.WriteLine($"llm_provider={config.LlmProvider}");
|
AnsiConsole.WriteLine($"llm_provider={config.LlmProvider}");
|
||||||
AnsiConsole.WriteLine($"llm_model={config.LlmModel}");
|
AnsiConsole.WriteLine($"llm_model={config.LlmModel}");
|
||||||
AnsiConsole.WriteLine($"whisper_model={config.WhisperModel}");
|
AnsiConsole.WriteLine($"whisper_model={config.WhisperModel}");
|
||||||
@@ -50,8 +51,9 @@ public static class ShowCommand
|
|||||||
table.AddColumn("Value");
|
table.AddColumn("Value");
|
||||||
|
|
||||||
table.AddRow("Whisper Provider", config.WhisperProvider);
|
table.AddRow("Whisper Provider", config.WhisperProvider);
|
||||||
table.AddRow("Groq API Key", MaskApiKey(config.GroqApiKey));
|
|
||||||
table.AddRow("LLM Provider", config.LlmProvider);
|
table.AddRow("LLM Provider", config.LlmProvider);
|
||||||
|
table.AddRow("Groq API Key", MaskApiKey(config.GroqApiKey));
|
||||||
|
table.AddRow("Fireworks API Key", MaskApiKey(config.FireworksApiKey));
|
||||||
table.AddRow("LLM Model", config.LlmModel);
|
table.AddRow("LLM Model", config.LlmModel);
|
||||||
table.AddRow("Whisper Model", config.WhisperModel);
|
table.AddRow("Whisper Model", config.WhisperModel);
|
||||||
table.AddRow("Audio Backend", config.AudioBackend);
|
table.AddRow("Audio Backend", config.AudioBackend);
|
||||||
|
|||||||
@@ -8,11 +8,10 @@ namespace Hush.Daemon;
|
|||||||
|
|
||||||
public class Orchestrator
|
public class Orchestrator
|
||||||
{
|
{
|
||||||
|
private static readonly HttpClient _httpClient = new();
|
||||||
|
|
||||||
private readonly ConfigManager _configManager;
|
private readonly ConfigManager _configManager;
|
||||||
private readonly IAudioRecorder _recorder;
|
private readonly IAudioRecorder _recorder;
|
||||||
private IAudioToTextProvider? _audioToTextProvider;
|
|
||||||
private ITextStreamingProvider? _textProvider;
|
|
||||||
private ITextInput? _textInput;
|
|
||||||
|
|
||||||
private string? _recordingPath;
|
private string? _recordingPath;
|
||||||
private DateTime? _recordingStartTime;
|
private DateTime? _recordingStartTime;
|
||||||
@@ -149,9 +148,19 @@ public class Orchestrator
|
|||||||
var provider = GetTextProvider(config);
|
var provider = GetTextProvider(config);
|
||||||
|
|
||||||
var prompt = $"""
|
var prompt = $"""
|
||||||
Process this spoken text for clarity and correctness. Fix any errors, add proper punctuation, and make it read naturally. Keep the original meaning intact.
|
You are a transcription post-processor. Your task is to clean up raw speech-to-text output and return polished, ready-to-type text.
|
||||||
|
|
||||||
Text: {text}
|
Rules:
|
||||||
|
- Detect the language of the transcription and process it entirely in that language — do not translate
|
||||||
|
- Fix grammar, spelling, and punctuation errors introduced by the speech recognizer, following the conventions of the detected language
|
||||||
|
- Capitalize sentences and proper nouns appropriately for the detected language
|
||||||
|
- Remove filler words and false starts appropriate to the detected language (e.g. "um", "uh", "like" in English; "euh", "bah" in French; "äh", "ähm" in German; "eh", "tipo" in Spanish/Italian)
|
||||||
|
- Preserve the speaker's original intent, vocabulary choices, and tone
|
||||||
|
- Do not add, remove, or reinterpret content beyond what was said
|
||||||
|
- Do not include any explanation, preamble, or metadata — output only the corrected text
|
||||||
|
- If the input is empty or unintelligible, return an empty string
|
||||||
|
|
||||||
|
Raw transcription: {text}
|
||||||
""";
|
""";
|
||||||
|
|
||||||
return await provider.CompleteTextAsync(prompt, config.LlmModel);
|
return await provider.CompleteTextAsync(prompt, config.LlmModel);
|
||||||
@@ -163,51 +172,36 @@ public class Orchestrator
|
|||||||
await input.TypeString(text);
|
await input.TypeString(text);
|
||||||
}
|
}
|
||||||
|
|
||||||
private IAudioToTextProvider GetAudioToTextProvider(HushConfig config)
|
private IAudioToTextProvider GetAudioToTextProvider(HushConfig config) =>
|
||||||
{
|
config.WhisperProvider switch
|
||||||
if (_audioToTextProvider != null)
|
|
||||||
return _audioToTextProvider;
|
|
||||||
|
|
||||||
_audioToTextProvider = config.WhisperProvider switch
|
|
||||||
{
|
{
|
||||||
"groq" => string.IsNullOrEmpty(config.GroqApiKey)
|
"groq" => string.IsNullOrEmpty(config.GroqApiKey)
|
||||||
? throw new InvalidOperationException("Groq API key is required for Whisper transcription")
|
? throw new InvalidOperationException("Groq API key is required for Whisper transcription")
|
||||||
: new GroqProvider(config.GroqApiKey),
|
: new GroqProvider(config.GroqApiKey, _httpClient),
|
||||||
|
"fireworks" => string.IsNullOrEmpty(config.FireworksApiKey)
|
||||||
|
? throw new InvalidOperationException("Fireworks API key is required for Whisper transcription")
|
||||||
|
: new FireworksProvider(config.FireworksApiKey, _httpClient),
|
||||||
_ => throw new InvalidOperationException($"Unsupported Whisper provider: {config.WhisperProvider}")
|
_ => throw new InvalidOperationException($"Unsupported Whisper provider: {config.WhisperProvider}")
|
||||||
};
|
};
|
||||||
|
|
||||||
return _audioToTextProvider;
|
|
||||||
}
|
|
||||||
|
|
||||||
private ITextStreamingProvider GetTextProvider(HushConfig config)
|
private ITextStreamingProvider GetTextProvider(HushConfig config) =>
|
||||||
{
|
config.LlmProvider switch
|
||||||
if (_textProvider != null)
|
|
||||||
return _textProvider;
|
|
||||||
|
|
||||||
_textProvider = config.LlmProvider switch
|
|
||||||
{
|
{
|
||||||
"groq" => string.IsNullOrEmpty(config.GroqApiKey)
|
"groq" => string.IsNullOrEmpty(config.GroqApiKey)
|
||||||
? throw new InvalidOperationException("Groq API key is required for LLM")
|
? throw new InvalidOperationException("Groq API key is required for LLM")
|
||||||
: new GroqProvider(config.GroqApiKey),
|
: new GroqProvider(config.GroqApiKey, _httpClient),
|
||||||
|
"fireworks" => string.IsNullOrEmpty(config.FireworksApiKey)
|
||||||
|
? throw new InvalidOperationException("Fireworks API key is required for LLM")
|
||||||
|
: new FireworksProvider(config.FireworksApiKey, _httpClient),
|
||||||
_ => throw new InvalidOperationException($"Unsupported LLM provider: {config.LlmProvider}")
|
_ => throw new InvalidOperationException($"Unsupported LLM provider: {config.LlmProvider}")
|
||||||
};
|
};
|
||||||
|
|
||||||
return _textProvider;
|
|
||||||
}
|
|
||||||
|
|
||||||
private ITextInput GetTextInput(HushConfig config)
|
private static ITextInput GetTextInput(HushConfig config) =>
|
||||||
{
|
config.TypingBackend switch
|
||||||
if (_textInput != null)
|
|
||||||
return _textInput;
|
|
||||||
|
|
||||||
_textInput = config.TypingBackend switch
|
|
||||||
{
|
{
|
||||||
"xdotool" => new XdotoolInput(),
|
"xdotool" => new XdotoolInput(),
|
||||||
_ => new WtypeInput()
|
_ => new WtypeInput()
|
||||||
};
|
};
|
||||||
|
|
||||||
return _textInput;
|
|
||||||
}
|
|
||||||
|
|
||||||
private IAudioRecorder CreateAudioRecorder()
|
private IAudioRecorder CreateAudioRecorder()
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -0,0 +1,211 @@
|
|||||||
|
using System.Net.Http.Headers;
|
||||||
|
using System.Text;
|
||||||
|
using System.Text.Json;
|
||||||
|
using Hush.Providers.Interfaces;
|
||||||
|
using Hush.Providers.Models.Request;
|
||||||
|
using Hush.Providers.Serialization;
|
||||||
|
|
||||||
|
namespace Hush.Providers.Providers;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Implementation of LLM provider for Fireworks AI API.
|
||||||
|
/// </summary>
|
||||||
|
public class FireworksProvider : IAudioToTextProvider, ITextStreamingProvider
|
||||||
|
{
|
||||||
|
private const string CHAT_COMPLETION_ENDPOINT = "https://api.fireworks.ai/inference/v1/chat/completions";
|
||||||
|
private const string TRANSCRIPTION_ENDPOINT_PROD = "https://audio-prod.api.fireworks.ai/v1/audio/transcriptions";
|
||||||
|
private const string TRANSCRIPTION_ENDPOINT_TURBO = "https://audio-turbo.api.fireworks.ai/v1/audio/transcriptions";
|
||||||
|
|
||||||
|
private readonly HttpClient _httpClient;
|
||||||
|
private readonly string _apiKey;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Initializes a new instance of the FireworksProvider class.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="apiKey">The Fireworks AI API key</param>
|
||||||
|
/// <param name="httpClient">Optional HttpClient instance (for testing)</param>
|
||||||
|
public FireworksProvider(string apiKey, HttpClient? httpClient = null)
|
||||||
|
{
|
||||||
|
_apiKey = apiKey ?? throw new ArgumentNullException(nameof(apiKey));
|
||||||
|
_httpClient = httpClient ?? new HttpClient();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <inheritdoc />
|
||||||
|
public async Task<string> TranscribeAsync(
|
||||||
|
Stream audioStream,
|
||||||
|
string modelName,
|
||||||
|
CancellationToken cancellationToken = default)
|
||||||
|
{
|
||||||
|
ArgumentNullException.ThrowIfNull(audioStream);
|
||||||
|
|
||||||
|
if (string.IsNullOrWhiteSpace(modelName))
|
||||||
|
throw new ArgumentException("Model name is required", nameof(modelName));
|
||||||
|
|
||||||
|
var endpoint = modelName.Contains("turbo", StringComparison.OrdinalIgnoreCase)
|
||||||
|
? TRANSCRIPTION_ENDPOINT_TURBO
|
||||||
|
: TRANSCRIPTION_ENDPOINT_PROD;
|
||||||
|
|
||||||
|
var request = new TranscriptionRequest { Model = modelName };
|
||||||
|
|
||||||
|
using var content = new MultipartFormDataContent();
|
||||||
|
content.Add(new StreamContent(audioStream), "file", "audio.wav");
|
||||||
|
content.Add(new StringContent(request.Model), "model");
|
||||||
|
if (request.ResponseFormat != null)
|
||||||
|
content.Add(new StringContent(request.ResponseFormat), "response_format");
|
||||||
|
if (request.Language != null)
|
||||||
|
content.Add(new StringContent(request.Language), "language");
|
||||||
|
if (request.Prompt != null)
|
||||||
|
content.Add(new StringContent(request.Prompt), "prompt");
|
||||||
|
if (request.Temperature.HasValue)
|
||||||
|
content.Add(new StringContent(request.Temperature.Value.ToString(System.Globalization.CultureInfo.InvariantCulture)), "temperature");
|
||||||
|
|
||||||
|
var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpoint)
|
||||||
|
{
|
||||||
|
Content = content
|
||||||
|
};
|
||||||
|
|
||||||
|
httpRequest.Headers.TryAddWithoutValidation("Authorization", _apiKey);
|
||||||
|
|
||||||
|
using var response = await _httpClient.SendAsync(httpRequest, cancellationToken).ConfigureAwait(false);
|
||||||
|
|
||||||
|
response.EnsureSuccessStatusCode();
|
||||||
|
|
||||||
|
var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
|
||||||
|
|
||||||
|
var result = JsonSerializer.Deserialize(
|
||||||
|
responseContent,
|
||||||
|
JsonSourceGeneration.Default.TranscriptionResponse);
|
||||||
|
|
||||||
|
if (result == null)
|
||||||
|
throw new InvalidOperationException("Failed to deserialize transcription response");
|
||||||
|
|
||||||
|
return result.Text;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <inheritdoc />
|
||||||
|
public async Task<string> CompleteTextAsync(
|
||||||
|
string prompt,
|
||||||
|
string modelName,
|
||||||
|
CancellationToken cancellationToken = default)
|
||||||
|
{
|
||||||
|
if (string.IsNullOrWhiteSpace(prompt))
|
||||||
|
throw new ArgumentException("Prompt is required", nameof(prompt));
|
||||||
|
|
||||||
|
if (string.IsNullOrWhiteSpace(modelName))
|
||||||
|
throw new ArgumentException("Model name is required", nameof(modelName));
|
||||||
|
|
||||||
|
var request = new ChatCompletionRequest
|
||||||
|
{
|
||||||
|
Model = modelName,
|
||||||
|
Messages = new List<Message> { new() { Role = "user", Content = prompt } }
|
||||||
|
};
|
||||||
|
|
||||||
|
var jsonContent = new StringContent(
|
||||||
|
JsonSerializer.Serialize(request, JsonSourceGeneration.Default.ChatCompletionRequest),
|
||||||
|
Encoding.UTF8,
|
||||||
|
"application/json");
|
||||||
|
|
||||||
|
var httpRequest = new HttpRequestMessage(HttpMethod.Post, CHAT_COMPLETION_ENDPOINT)
|
||||||
|
{
|
||||||
|
Content = jsonContent
|
||||||
|
};
|
||||||
|
|
||||||
|
httpRequest.Headers.TryAddWithoutValidation("Authorization", _apiKey);
|
||||||
|
|
||||||
|
using var response = await _httpClient.SendAsync(httpRequest, cancellationToken).ConfigureAwait(false);
|
||||||
|
|
||||||
|
response.EnsureSuccessStatusCode();
|
||||||
|
|
||||||
|
var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
|
||||||
|
|
||||||
|
var result = JsonSerializer.Deserialize(responseContent, JsonSourceGeneration.Default.ChatCompletionResponse);
|
||||||
|
|
||||||
|
if (result == null || result.Choices.Count == 0)
|
||||||
|
throw new InvalidOperationException("Failed to deserialize chat completion response");
|
||||||
|
|
||||||
|
return result.Choices[0].Message.Content;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <inheritdoc />
|
||||||
|
public async IAsyncEnumerable<string> StreamTextAsync(
|
||||||
|
string prompt,
|
||||||
|
string modelName,
|
||||||
|
[System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||||
|
{
|
||||||
|
if (string.IsNullOrWhiteSpace(prompt))
|
||||||
|
throw new ArgumentException("Prompt is required", nameof(prompt));
|
||||||
|
|
||||||
|
if (string.IsNullOrWhiteSpace(modelName))
|
||||||
|
throw new ArgumentException("Model name is required", nameof(modelName));
|
||||||
|
|
||||||
|
var request = new ChatCompletionRequest
|
||||||
|
{
|
||||||
|
Model = modelName,
|
||||||
|
Stream = true,
|
||||||
|
Messages = new List<Message> { new() { Role = "user", Content = prompt } }
|
||||||
|
};
|
||||||
|
|
||||||
|
var jsonContent = new StringContent(
|
||||||
|
JsonSerializer.Serialize(request, JsonSourceGeneration.Default.ChatCompletionRequest),
|
||||||
|
Encoding.UTF8,
|
||||||
|
"application/json");
|
||||||
|
|
||||||
|
var httpRequest = new HttpRequestMessage(HttpMethod.Post, CHAT_COMPLETION_ENDPOINT)
|
||||||
|
{
|
||||||
|
Content = jsonContent
|
||||||
|
};
|
||||||
|
|
||||||
|
httpRequest.Headers.TryAddWithoutValidation("Authorization", _apiKey);
|
||||||
|
|
||||||
|
using var response = await _httpClient.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
|
||||||
|
|
||||||
|
response.EnsureSuccessStatusCode();
|
||||||
|
|
||||||
|
using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
|
||||||
|
using var reader = new StreamReader(stream);
|
||||||
|
|
||||||
|
string? line;
|
||||||
|
while ((line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) != null)
|
||||||
|
{
|
||||||
|
if (string.IsNullOrWhiteSpace(line) || !line.StartsWith("data: "))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
var data = line.Substring(6).Trim(); // Remove "data: " prefix
|
||||||
|
|
||||||
|
if (data == "[DONE]")
|
||||||
|
break;
|
||||||
|
|
||||||
|
var text = ParseTextFromStreamData(data);
|
||||||
|
if (!string.IsNullOrEmpty(text))
|
||||||
|
yield return text;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static string? ParseTextFromStreamData(string data)
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
using var jsonDoc = JsonDocument.Parse(data);
|
||||||
|
var choices = jsonDoc.RootElement.GetProperty("choices");
|
||||||
|
var choice = choices[0];
|
||||||
|
|
||||||
|
if (choice.TryGetProperty("delta", out var delta))
|
||||||
|
{
|
||||||
|
if (delta.TryGetProperty("content", out var content))
|
||||||
|
{
|
||||||
|
return content.GetString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (choice.TryGetProperty("text", out var text))
|
||||||
|
{
|
||||||
|
return text.GetString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (JsonException)
|
||||||
|
{
|
||||||
|
// Skip malformed JSON chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user