add Fireworks AI provider support
This commit is contained in:
@@ -0,0 +1,211 @@
|
||||
using System.Net.Http.Headers;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using Hush.Providers.Interfaces;
|
||||
using Hush.Providers.Models.Request;
|
||||
using Hush.Providers.Serialization;
|
||||
|
||||
namespace Hush.Providers.Providers;
|
||||
|
||||
/// <summary>
|
||||
/// Implementation of LLM provider for Fireworks AI API.
|
||||
/// </summary>
|
||||
public class FireworksProvider : IAudioToTextProvider, ITextStreamingProvider
|
||||
{
|
||||
private const string CHAT_COMPLETION_ENDPOINT = "https://api.fireworks.ai/inference/v1/chat/completions";
|
||||
private const string TRANSCRIPTION_ENDPOINT_PROD = "https://audio-prod.api.fireworks.ai/v1/audio/transcriptions";
|
||||
private const string TRANSCRIPTION_ENDPOINT_TURBO = "https://audio-turbo.api.fireworks.ai/v1/audio/transcriptions";
|
||||
|
||||
private readonly HttpClient _httpClient;
|
||||
private readonly string _apiKey;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the FireworksProvider class.
|
||||
/// </summary>
|
||||
/// <param name="apiKey">The Fireworks AI API key</param>
|
||||
/// <param name="httpClient">Optional HttpClient instance (for testing)</param>
|
||||
public FireworksProvider(string apiKey, HttpClient? httpClient = null)
|
||||
{
|
||||
_apiKey = apiKey ?? throw new ArgumentNullException(nameof(apiKey));
|
||||
_httpClient = httpClient ?? new HttpClient();
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<string> TranscribeAsync(
|
||||
Stream audioStream,
|
||||
string modelName,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(audioStream);
|
||||
|
||||
if (string.IsNullOrWhiteSpace(modelName))
|
||||
throw new ArgumentException("Model name is required", nameof(modelName));
|
||||
|
||||
var endpoint = modelName.Contains("turbo", StringComparison.OrdinalIgnoreCase)
|
||||
? TRANSCRIPTION_ENDPOINT_TURBO
|
||||
: TRANSCRIPTION_ENDPOINT_PROD;
|
||||
|
||||
var request = new TranscriptionRequest { Model = modelName };
|
||||
|
||||
using var content = new MultipartFormDataContent();
|
||||
content.Add(new StreamContent(audioStream), "file", "audio.wav");
|
||||
content.Add(new StringContent(request.Model), "model");
|
||||
if (request.ResponseFormat != null)
|
||||
content.Add(new StringContent(request.ResponseFormat), "response_format");
|
||||
if (request.Language != null)
|
||||
content.Add(new StringContent(request.Language), "language");
|
||||
if (request.Prompt != null)
|
||||
content.Add(new StringContent(request.Prompt), "prompt");
|
||||
if (request.Temperature.HasValue)
|
||||
content.Add(new StringContent(request.Temperature.Value.ToString(System.Globalization.CultureInfo.InvariantCulture)), "temperature");
|
||||
|
||||
var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpoint)
|
||||
{
|
||||
Content = content
|
||||
};
|
||||
|
||||
httpRequest.Headers.TryAddWithoutValidation("Authorization", _apiKey);
|
||||
|
||||
using var response = await _httpClient.SendAsync(httpRequest, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
response.EnsureSuccessStatusCode();
|
||||
|
||||
var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
var result = JsonSerializer.Deserialize(
|
||||
responseContent,
|
||||
JsonSourceGeneration.Default.TranscriptionResponse);
|
||||
|
||||
if (result == null)
|
||||
throw new InvalidOperationException("Failed to deserialize transcription response");
|
||||
|
||||
return result.Text;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<string> CompleteTextAsync(
|
||||
string prompt,
|
||||
string modelName,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(prompt))
|
||||
throw new ArgumentException("Prompt is required", nameof(prompt));
|
||||
|
||||
if (string.IsNullOrWhiteSpace(modelName))
|
||||
throw new ArgumentException("Model name is required", nameof(modelName));
|
||||
|
||||
var request = new ChatCompletionRequest
|
||||
{
|
||||
Model = modelName,
|
||||
Messages = new List<Message> { new() { Role = "user", Content = prompt } }
|
||||
};
|
||||
|
||||
var jsonContent = new StringContent(
|
||||
JsonSerializer.Serialize(request, JsonSourceGeneration.Default.ChatCompletionRequest),
|
||||
Encoding.UTF8,
|
||||
"application/json");
|
||||
|
||||
var httpRequest = new HttpRequestMessage(HttpMethod.Post, CHAT_COMPLETION_ENDPOINT)
|
||||
{
|
||||
Content = jsonContent
|
||||
};
|
||||
|
||||
httpRequest.Headers.TryAddWithoutValidation("Authorization", _apiKey);
|
||||
|
||||
using var response = await _httpClient.SendAsync(httpRequest, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
response.EnsureSuccessStatusCode();
|
||||
|
||||
var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
var result = JsonSerializer.Deserialize(responseContent, JsonSourceGeneration.Default.ChatCompletionResponse);
|
||||
|
||||
if (result == null || result.Choices.Count == 0)
|
||||
throw new InvalidOperationException("Failed to deserialize chat completion response");
|
||||
|
||||
return result.Choices[0].Message.Content;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async IAsyncEnumerable<string> StreamTextAsync(
|
||||
string prompt,
|
||||
string modelName,
|
||||
[System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(prompt))
|
||||
throw new ArgumentException("Prompt is required", nameof(prompt));
|
||||
|
||||
if (string.IsNullOrWhiteSpace(modelName))
|
||||
throw new ArgumentException("Model name is required", nameof(modelName));
|
||||
|
||||
var request = new ChatCompletionRequest
|
||||
{
|
||||
Model = modelName,
|
||||
Stream = true,
|
||||
Messages = new List<Message> { new() { Role = "user", Content = prompt } }
|
||||
};
|
||||
|
||||
var jsonContent = new StringContent(
|
||||
JsonSerializer.Serialize(request, JsonSourceGeneration.Default.ChatCompletionRequest),
|
||||
Encoding.UTF8,
|
||||
"application/json");
|
||||
|
||||
var httpRequest = new HttpRequestMessage(HttpMethod.Post, CHAT_COMPLETION_ENDPOINT)
|
||||
{
|
||||
Content = jsonContent
|
||||
};
|
||||
|
||||
httpRequest.Headers.TryAddWithoutValidation("Authorization", _apiKey);
|
||||
|
||||
using var response = await _httpClient.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
response.EnsureSuccessStatusCode();
|
||||
|
||||
using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
|
||||
using var reader = new StreamReader(stream);
|
||||
|
||||
string? line;
|
||||
while ((line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) != null)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(line) || !line.StartsWith("data: "))
|
||||
continue;
|
||||
|
||||
var data = line.Substring(6).Trim(); // Remove "data: " prefix
|
||||
|
||||
if (data == "[DONE]")
|
||||
break;
|
||||
|
||||
var text = ParseTextFromStreamData(data);
|
||||
if (!string.IsNullOrEmpty(text))
|
||||
yield return text;
|
||||
}
|
||||
}
|
||||
|
||||
private static string? ParseTextFromStreamData(string data)
|
||||
{
|
||||
try
|
||||
{
|
||||
using var jsonDoc = JsonDocument.Parse(data);
|
||||
var choices = jsonDoc.RootElement.GetProperty("choices");
|
||||
var choice = choices[0];
|
||||
|
||||
if (choice.TryGetProperty("delta", out var delta))
|
||||
{
|
||||
if (delta.TryGetProperty("content", out var content))
|
||||
{
|
||||
return content.GetString();
|
||||
}
|
||||
}
|
||||
else if (choice.TryGetProperty("text", out var text))
|
||||
{
|
||||
return text.GetString();
|
||||
}
|
||||
}
|
||||
catch (JsonException)
|
||||
{
|
||||
// Skip malformed JSON chunks
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user