fix: Fixed ConfiguredCancelableAsyncEnumerable usage.
HavenDV committed May 11, 2024
1 parent a014b07 commit fc1070a
Showing 4 changed files with 73 additions and 31 deletions.
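Context for the change: the streaming methods previously returned ConfiguredCancelableAsyncEnumerable<T>, the struct produced by calling ConfigureAwait(false) on an IAsyncEnumerable<T>. That struct can be consumed with await foreach, but it does not implement IAsyncEnumerable<T>, so callers could not treat the result as a standard async stream. The commit turns each method into an async iterator that keeps ConfigureAwait(false) internal and exposes a plain IAsyncEnumerable<T>, adding [EnumeratorCancellation] so a token supplied via WithCancellation still reaches the underlying call. A minimal sketch of the pattern, using illustrative names rather than the library's own (InnerStreamAsync and StreamAsync are hypothetical stand-ins; the real code wraps the private StreamPostAsync helper):

using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

public static class StreamingWrapperSketch
{
    // Hypothetical inner stream standing in for the private StreamPostAsync helper.
    private static async IAsyncEnumerable<string> InnerStreamAsync(
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        for (var i = 0; i < 3; i++)
        {
            await Task.Delay(10, cancellationToken).ConfigureAwait(false);
            yield return $"chunk {i}";
        }
    }

    // Before: public ConfiguredCancelableAsyncEnumerable<string> StreamAsync(...)
    //             => InnerStreamAsync(...).ConfigureAwait(false);
    // After: an async iterator that re-yields the configured enumeration,
    // so the public return type is a plain IAsyncEnumerable<string>.
    public static async IAsyncEnumerable<string> StreamAsync(
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        var enumerable = InnerStreamAsync(cancellationToken).ConfigureAwait(false);

        await foreach (var item in enumerable)
        {
            yield return item;
        }
    }
}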
2 changes: 1 addition & 1 deletion src/libs/Directory.Build.props
@@ -14,7 +14,7 @@
</ItemGroup>

<PropertyGroup Label="Nuget">
-<Version>0.9.3</Version>
+<Version>0.9.4</Version>
<GeneratePackageOnBuild Condition=" '$(Configuration)' == 'Release' ">true</GeneratePackageOnBuild>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<Authors>tryAGI and contributors</Authors>
33 changes: 24 additions & 9 deletions src/libs/Ollama/OllamaApiClient.cs
@@ -54,16 +54,21 @@ public OllamaApiClient(HttpClient client)
/// </summary>
/// <param name="request">The parameters for the model to create</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
-public ConfiguredCancelableAsyncEnumerable<CreateModelResponse> CreateModelAsync(
+public async IAsyncEnumerable<CreateModelResponse> CreateModelAsync(
CreateModelRequest request,
-CancellationToken cancellationToken = default)
+[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
-return StreamPostAsync(
+var enumerable = StreamPostAsync(
"api/create",
request,
SourceGenerationContext.Default.CreateModelRequest,
SourceGenerationContext.Default.CreateModelResponse,
cancellationToken).ConfigureAwait(false);

+await foreach (var response in enumerable)
+{
+yield return response;
+}
}

/// <summary>
@@ -142,33 +147,43 @@ await PostAsync(
/// </summary>
/// <param name="request">The request parameters</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
-public ConfiguredCancelableAsyncEnumerable<PullModelResponse> PullModelAsync(
+public async IAsyncEnumerable<PullModelResponse> PullModelAsync(
PullModelRequest request,
-CancellationToken cancellationToken = default)
+[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
-return StreamPostAsync(
+var enumerable = StreamPostAsync(
"api/pull",
request,
SourceGenerationContext.Default.PullModelRequest,
SourceGenerationContext.Default.PullModelResponse,
cancellationToken).ConfigureAwait(false);

+await foreach (var response in enumerable)
+{
+yield return response;
+}
}

/// <summary>
/// Sends a request to the /api/push endpoint to push a new model
/// </summary>
/// <param name="request">The request parameters</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
-public ConfiguredCancelableAsyncEnumerable<PushModelResponse> PushModelAsync(
+public async IAsyncEnumerable<PushModelResponse> PushModelAsync(
PushModelRequest request,
-CancellationToken cancellationToken = default)
+[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
-return StreamPostAsync(
+var enumerable = StreamPostAsync(
"api/push",
request,
SourceGenerationContext.Default.PushModelRequest,
SourceGenerationContext.Default.PushModelResponse,
cancellationToken).ConfigureAwait(false);

+await foreach (var response in enumerable)
+{
+yield return response;
+}
}

/// <summary>
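For callers of OllamaApiClient, the practical effect is that the streams are now ordinary IAsyncEnumerable<T> values, and the ConfigureAwait/WithCancellation decision moves to the call site. A hedged usage sketch, not taken from the repository: the base address and model name are placeholders, and the `using Ollama;` import is a guess at the package namespace.

using System;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Ollama; // assumed package namespace

using var httpClient = new HttpClient
{
    BaseAddress = new Uri("http://localhost:11434/"), // placeholder endpoint
};
var apiClient = new OllamaApiClient(httpClient);

using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(10));

// PullModelAsync now returns IAsyncEnumerable<PullModelResponse>, so the caller
// chooses ConfigureAwait behaviour instead of receiving a pre-configured
// ConfiguredCancelableAsyncEnumerable.
await foreach (var response in apiClient
                   .PullModelAsync(new PullModelRequest { Name = "llama3" }, cts.Token)
                   .ConfigureAwait(false))
{
    Console.WriteLine(response); // progress objects as they stream in
}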
67 changes: 47 additions & 20 deletions src/libs/Ollama/OllamaApiClientExtensions.cs
@@ -31,14 +31,19 @@ public static Chat Chat(
/// <param name="chatRequest">The request to send to Ollama</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
/// <returns>List of the returned messages including the previous context</returns>
-public static ConfiguredCancelableAsyncEnumerable<GenerateChatCompletionResponse> SendChatAsync(
+public static async IAsyncEnumerable<GenerateChatCompletionResponse> SendChatAsync(
this OllamaApiClient client,
GenerateChatCompletionRequest chatRequest,
-CancellationToken cancellationToken = default)
+[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
client = client ?? throw new ArgumentNullException(nameof(client));

-return client.SendChatAsync(chatRequest, cancellationToken).ConfigureAwait(false);
+var enumerable = client.SendChatAsync(chatRequest, cancellationToken).ConfigureAwait(false);

+await foreach (var response in enumerable)
+{
+yield return response;
+}
}

/// <summary>
@@ -74,22 +79,27 @@ await client.CopyModelAsync(new CopyModelRequest
/// </param>
/// <param name="path">The name path to the model file</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
-public static ConfiguredCancelableAsyncEnumerable<CreateModelResponse> CreateModelAsync(
+public static async IAsyncEnumerable<CreateModelResponse> CreateModelAsync(
this OllamaApiClient client,
string name,
string modelFileContent,
string? path = null,
-CancellationToken cancellationToken = default)
+[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
client = client ?? throw new ArgumentNullException(nameof(client));

-return client.CreateModelAsync(new CreateModelRequest
+var enumerable = client.CreateModelAsync(new CreateModelRequest
{
Name = name,
Modelfile = modelFileContent,
Path = path,
Stream = true,
}, cancellationToken).ConfigureAwait(false);

+await foreach (var response in enumerable)
+{
+yield return response;
+}
}

/// <summary>
@@ -98,17 +108,22 @@ public static ConfiguredCancelableAsyncEnumerable<CreateModelResponse> CreateMod
/// <param name="client"></param>
/// <param name="model">The name of the model to pull</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
-public static ConfiguredCancelableAsyncEnumerable<PullModelResponse> PullModelAsync(
+public static async IAsyncEnumerable<PullModelResponse> PullModelAsync(
this OllamaApiClient client,
string model,
-CancellationToken cancellationToken = default)
+[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
client = client ?? throw new ArgumentNullException(nameof(client));

-return client.PullModelAsync(new PullModelRequest
+var enumerable = client.PullModelAsync(new PullModelRequest
{
Name = model,
}, cancellationToken).ConfigureAwait(false);

+await foreach (var response in enumerable)
+{
+yield return response;
+}
}

/// <summary>
@@ -117,18 +132,23 @@ public static ConfiguredCancelableAsyncEnumerable<PullModelResponse> PullModelAs
/// <param name="client"></param>
/// <param name="name">The name of the model to push</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
-public static ConfiguredCancelableAsyncEnumerable<PushModelResponse> PushModelAsync(
+public static async IAsyncEnumerable<PushModelResponse> PushModelAsync(
this OllamaApiClient client,
string name,
-CancellationToken cancellationToken = default)
+[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
client = client ?? throw new ArgumentNullException(nameof(client));

-return client.PushModelAsync(new PushModelRequest
+var enumerable = client.PushModelAsync(new PushModelRequest
{
Name = name,
Stream = true,
}, cancellationToken).ConfigureAwait(false);

+await foreach (var response in enumerable)
+{
+yield return response;
+}
}

/// <summary>
@@ -163,25 +183,28 @@ public static async Task<GenerateEmbeddingResponse> GenerateEmbeddingsAsync(
/// </param>
/// <param name="stream"></param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
-public static ConfiguredCancelableAsyncEnumerable<GenerateCompletionResponse> GetCompletionAsync(
+public static async IAsyncEnumerable<GenerateCompletionResponse> GetCompletionAsync(
this OllamaApiClient client,
string model,
string prompt,
bool stream = true,
IList<long>? context = null,
-CancellationToken cancellationToken = default)
+[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
client = client ?? throw new ArgumentNullException(nameof(client));

-var request = new GenerateCompletionRequest
+var enumerable = client.GetCompletionAsync(new GenerateCompletionRequest
{
Prompt = prompt,
Model = model,
Stream = stream,
Context = context ?? [],
-};
-
-return client.GetCompletionAsync(request, cancellationToken).ConfigureAwait(false);
+}, cancellationToken).ConfigureAwait(false);

+await foreach (var response in enumerable)
+{
+yield return response;
+}
}

/// <summary>
@@ -190,8 +213,10 @@ public static ConfiguredCancelableAsyncEnumerable<GenerateCompletionResponse> Ge
/// <param name="enumerable"></param>
/// <returns></returns>
public static async Task<GenerateCompletionResponse> WaitAsync(
-this ConfiguredCancelableAsyncEnumerable<GenerateCompletionResponse> enumerable)
+this IAsyncEnumerable<GenerateCompletionResponse> enumerable)
{
+enumerable = enumerable ?? throw new ArgumentNullException(nameof(enumerable));

var text = string.Empty;
var currentResponse = new GenerateCompletionResponse();
await foreach (var response in enumerable)
@@ -242,8 +267,10 @@ public static async Task<GenerateChatCompletionResponse> WaitAsync(
/// <param name="enumerable"></param>
/// <returns></returns>
public static async Task<T> WaitAsync<T>(
-this ConfiguredCancelableAsyncEnumerable<T> enumerable) where T : new()
+this IAsyncEnumerable<T> enumerable) where T : new()
{
+enumerable = enumerable ?? throw new ArgumentNullException(nameof(enumerable));

var currentResponse = new T();
await foreach (var response in enumerable)
{
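Because these extensions now return IAsyncEnumerable<T>, they compose directly with the WaitAsync helpers in the same file, which drain the stream and return a single aggregated response. Continuing the sketch above (the model name and prompt are placeholders):

// Stream a completion and keep only the final aggregated result.
var completion = await apiClient
    .GetCompletionAsync(model: "llama3", prompt: "Why is the sky blue?")
    .WaitAsync()
    .ConfigureAwait(false);

Console.WriteLine(completion); // the aggregated GenerateCompletionResponse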
2 changes: 1 addition & 1 deletion src/tests/Ollama.IntegrationTests/Tests.Integration.cs
@@ -10,7 +10,7 @@ private static async Task<Environment> PrepareEnvironmentAsync(EnvironmentType e
{
case EnvironmentType.Local:
{
-using var client = new HttpClient();
+var client = new HttpClient();
client.BaseAddress = new Uri("http://172.16.50.107:11434/");
var apiClient = new OllamaApiClient(client);

