add Fireworks AI provider support

This commit is contained in:
2026-03-22 19:02:37 +01:00
parent a26d453720
commit 70e784a1cc
4 changed files with 310 additions and 73 deletions
@@ -0,0 +1,211 @@
using System.Net.Http.Headers;
using System.Text;
using System.Text.Json;
using Hush.Providers.Interfaces;
using Hush.Providers.Models.Request;
using Hush.Providers.Serialization;
namespace Hush.Providers.Providers;
/// <summary>
/// Implementation of LLM provider for Fireworks AI API.
/// </summary>
public class FireworksProvider : IAudioToTextProvider, ITextStreamingProvider
{
private const string CHAT_COMPLETION_ENDPOINT = "https://api.fireworks.ai/inference/v1/chat/completions";
private const string TRANSCRIPTION_ENDPOINT_PROD = "https://audio-prod.api.fireworks.ai/v1/audio/transcriptions";
private const string TRANSCRIPTION_ENDPOINT_TURBO = "https://audio-turbo.api.fireworks.ai/v1/audio/transcriptions";
private readonly HttpClient _httpClient;
private readonly string _apiKey;
/// <summary>
/// Initializes a new instance of the FireworksProvider class.
/// </summary>
/// <param name="apiKey">The Fireworks AI API key</param>
/// <param name="httpClient">Optional HttpClient instance (for testing)</param>
public FireworksProvider(string apiKey, HttpClient? httpClient = null)
{
_apiKey = apiKey ?? throw new ArgumentNullException(nameof(apiKey));
_httpClient = httpClient ?? new HttpClient();
}
/// <inheritdoc />
public async Task<string> TranscribeAsync(
Stream audioStream,
string modelName,
CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(audioStream);
if (string.IsNullOrWhiteSpace(modelName))
throw new ArgumentException("Model name is required", nameof(modelName));
var endpoint = modelName.Contains("turbo", StringComparison.OrdinalIgnoreCase)
? TRANSCRIPTION_ENDPOINT_TURBO
: TRANSCRIPTION_ENDPOINT_PROD;
var request = new TranscriptionRequest { Model = modelName };
using var content = new MultipartFormDataContent();
content.Add(new StreamContent(audioStream), "file", "audio.wav");
content.Add(new StringContent(request.Model), "model");
if (request.ResponseFormat != null)
content.Add(new StringContent(request.ResponseFormat), "response_format");
if (request.Language != null)
content.Add(new StringContent(request.Language), "language");
if (request.Prompt != null)
content.Add(new StringContent(request.Prompt), "prompt");
if (request.Temperature.HasValue)
content.Add(new StringContent(request.Temperature.Value.ToString(System.Globalization.CultureInfo.InvariantCulture)), "temperature");
var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpoint)
{
Content = content
};
httpRequest.Headers.TryAddWithoutValidation("Authorization", _apiKey);
using var response = await _httpClient.SendAsync(httpRequest, cancellationToken).ConfigureAwait(false);
response.EnsureSuccessStatusCode();
var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
var result = JsonSerializer.Deserialize(
responseContent,
JsonSourceGeneration.Default.TranscriptionResponse);
if (result == null)
throw new InvalidOperationException("Failed to deserialize transcription response");
return result.Text;
}
/// <inheritdoc />
public async Task<string> CompleteTextAsync(
string prompt,
string modelName,
CancellationToken cancellationToken = default)
{
if (string.IsNullOrWhiteSpace(prompt))
throw new ArgumentException("Prompt is required", nameof(prompt));
if (string.IsNullOrWhiteSpace(modelName))
throw new ArgumentException("Model name is required", nameof(modelName));
var request = new ChatCompletionRequest
{
Model = modelName,
Messages = new List<Message> { new() { Role = "user", Content = prompt } }
};
var jsonContent = new StringContent(
JsonSerializer.Serialize(request, JsonSourceGeneration.Default.ChatCompletionRequest),
Encoding.UTF8,
"application/json");
var httpRequest = new HttpRequestMessage(HttpMethod.Post, CHAT_COMPLETION_ENDPOINT)
{
Content = jsonContent
};
httpRequest.Headers.TryAddWithoutValidation("Authorization", _apiKey);
using var response = await _httpClient.SendAsync(httpRequest, cancellationToken).ConfigureAwait(false);
response.EnsureSuccessStatusCode();
var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
var result = JsonSerializer.Deserialize(responseContent, JsonSourceGeneration.Default.ChatCompletionResponse);
if (result == null || result.Choices.Count == 0)
throw new InvalidOperationException("Failed to deserialize chat completion response");
return result.Choices[0].Message.Content;
}
/// <inheritdoc />
public async IAsyncEnumerable<string> StreamTextAsync(
string prompt,
string modelName,
[System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (string.IsNullOrWhiteSpace(prompt))
throw new ArgumentException("Prompt is required", nameof(prompt));
if (string.IsNullOrWhiteSpace(modelName))
throw new ArgumentException("Model name is required", nameof(modelName));
var request = new ChatCompletionRequest
{
Model = modelName,
Stream = true,
Messages = new List<Message> { new() { Role = "user", Content = prompt } }
};
var jsonContent = new StringContent(
JsonSerializer.Serialize(request, JsonSourceGeneration.Default.ChatCompletionRequest),
Encoding.UTF8,
"application/json");
var httpRequest = new HttpRequestMessage(HttpMethod.Post, CHAT_COMPLETION_ENDPOINT)
{
Content = jsonContent
};
httpRequest.Headers.TryAddWithoutValidation("Authorization", _apiKey);
using var response = await _httpClient.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
response.EnsureSuccessStatusCode();
using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
using var reader = new StreamReader(stream);
string? line;
while ((line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) != null)
{
if (string.IsNullOrWhiteSpace(line) || !line.StartsWith("data: "))
continue;
var data = line.Substring(6).Trim(); // Remove "data: " prefix
if (data == "[DONE]")
break;
var text = ParseTextFromStreamData(data);
if (!string.IsNullOrEmpty(text))
yield return text;
}
}
private static string? ParseTextFromStreamData(string data)
{
try
{
using var jsonDoc = JsonDocument.Parse(data);
var choices = jsonDoc.RootElement.GetProperty("choices");
var choice = choices[0];
if (choice.TryGetProperty("delta", out var delta))
{
if (delta.TryGetProperty("content", out var content))
{
return content.GetString();
}
}
else if (choice.TryGetProperty("text", out var text))
{
return text.GetString();
}
}
catch (JsonException)
{
// Skip malformed JSON chunks
}
return null;
}
}