-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathTextEmbeddingResponse.cs
More file actions
119 lines (111 loc) · 5.73 KB
/
TextEmbeddingResponse.cs
File metadata and controls
119 lines (111 loc) · 5.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Net.Http;
using System.Collections;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System.IO;
namespace rosette_api
{
    /// <summary>
    /// Class for representing responses from the API when the TextEmbedding endpoint has been called
    /// </summary>
    [JsonObject(MemberSerialization.OptOut)]
    public class TextEmbeddingResponse : RosetteResponse
    {
        /// <summary>
        /// Gets the averaged text vector (serialized as "documentEmbedding")
        /// </summary>
        [JsonProperty(embeddingKey)]
        public IEnumerable<double> TextEmbedding { get; set; }
        /// <summary>
        /// Lists the tokens in the document (serialized as "tokens")
        /// </summary>
        [JsonProperty(tokenKey)]
        public IEnumerable<string> tokens { get; set; }
        /// <summary>
        /// Lists the token embeddings, one vector per entry (serialized as "tokenEmbeddings")
        /// </summary>
        [JsonProperty(tokenEmbeddingsKey)]
        public IEnumerable<List<double>> tokenEmbeddings { get; set; }

        // JSON field names, used both by the [JsonProperty] mappings above and when reading
        // the raw response content in the HttpResponseMessage constructor.
        private const String embeddingKey = "documentEmbedding";
        private const String tokenKey = "tokens";
        private const String tokenEmbeddingsKey = "tokenEmbeddings";

        /// <summary>
        /// Creates a TextEmbeddingResponse from the API's raw output
        /// </summary>
        /// <remarks>
        /// A missing or non-array "documentEmbedding" field yields an empty <see cref="TextEmbedding"/>.
        /// Missing or non-array "tokens" / "tokenEmbeddings" fields leave the corresponding property null.
        /// </remarks>
        /// <param name="apiResult">The API's output</param>
        public TextEmbeddingResponse(HttpResponseMessage apiResult)
            : base(apiResult)
        {
            List<double> textEmbedding = new List<double>();
            JArray embeddingArr = GetContentArray(embeddingKey);
            if (embeddingArr != null)
            {
                foreach (JValue result in embeddingArr)
                {
                    textEmbedding.Add(result.ToObject<double>());
                }
            }
            this.TextEmbedding = textEmbedding;

            JArray tokensArr = GetContentArray(tokenKey);
            this.tokens = tokensArr != null
                ? tokensArr.Select(jToken => jToken?.ToString()).ToList()
                : null;

            JArray tokenEmbeddingsArr = GetContentArray(tokenEmbeddingsKey);
            this.tokenEmbeddings = tokenEmbeddingsArr != null
                ? tokenEmbeddingsArr.Select(jToken => jToken?.ToObject<List<double>>()).ToList()
                : null;
        }

        /// <summary>
        /// Constructs a TextEmbedding Response from a text embedding, a collection of response headers, and content in a dictionary or content as JSON
        /// </summary>
        /// <param name="textEmbedding">The averaged text vector (text embedding)</param>
        /// <param name="tokens">The tokens from the document</param>
        /// <param name="tokenEmbeddings">The embeddings for each token</param>
        /// <param name="responseHeaders">The response headers from the API</param>
        /// <param name="content">The content of the response (i.e. the textEmbedding list)</param>
        /// <param name="contentAsJson">The content as a JSON string</param>
        public TextEmbeddingResponse(IEnumerable<double> textEmbedding, IEnumerable<string> tokens, IEnumerable<List<double>> tokenEmbeddings,
            Dictionary<string, string> responseHeaders, Dictionary<string, object> content = null, String contentAsJson = null)
            : base(responseHeaders, content, contentAsJson)
        {
            this.TextEmbedding = textEmbedding;
            this.tokens = tokens;
            this.tokenEmbeddings = tokenEmbeddings;
        }

        /// <summary>
        /// Equals override. Two responses are equal when their text embeddings, tokens and token
        /// embeddings are element-wise equal in order, and their response headers are equal.
        /// </summary>
        /// <param name="obj">The object to compare</param>
        /// <returns>True if equal</returns>
        public override bool Equals(object obj)
        {
            TextEmbeddingResponse other = obj as TextEmbeddingResponse;
            if (other == null)
            {
                return false;
            }

            bool headersEqual = this.ResponseHeaders != null && other.ResponseHeaders != null
                ? this.ResponseHeaders.Equals(other.ResponseHeaders)
                : this.ResponseHeaders == other.ResponseHeaders;

            return SequencesEqual(this.TextEmbedding, other.TextEmbedding)
                && SequencesEqual(this.tokens, other.tokens)
                // Compare token embeddings pairwise and in order; each pair is compared by contents.
                && SequencesEqual(this.tokenEmbeddings, other.tokenEmbeddings, DoubleListComparer.Instance)
                && headersEqual;
        }

        /// <summary>
        /// Hashcode override. Hashes collection contents (not references) so that objects
        /// considered equal by <see cref="Equals(object)"/> produce the same hash code.
        /// </summary>
        /// <returns>The hashcode</returns>
        public override int GetHashCode()
        {
            int h0 = this.ResponseHeaders != null ? this.ResponseHeaders.GetHashCode() : 1;
            int h1 = SequenceHash(this.TextEmbedding);
            int h2 = SequenceHash(this.tokens);
            int h3 = this.tokenEmbeddings != null
                ? SequenceHash(this.tokenEmbeddings.Select(embedding => SequenceHash(embedding)))
                : 1;
            return h0 ^ h1 ^ h2 ^ h3;
        }

        /// <summary>
        /// Returns the value stored under <paramref name="key"/> in ContentDictionary as a JArray,
        /// or null when the dictionary is null, lacks the key, or holds a non-array value there.
        /// </summary>
        private JArray GetContentArray(string key)
        {
            if (this.ContentDictionary == null || !this.ContentDictionary.ContainsKey(key))
            {
                return null;
            }
            return this.ContentDictionary[key] as JArray;
        }

        /// <summary>
        /// Order-sensitive, element-wise comparison of two sequences. Two null sequences are equal;
        /// a null and a non-null sequence are not. A null <paramref name="comparer"/> uses the default equality.
        /// </summary>
        private static bool SequencesEqual<T>(IEnumerable<T> left, IEnumerable<T> right, IEqualityComparer<T> comparer = null)
        {
            if (left == null || right == null)
            {
                return ReferenceEquals(left, right);
            }
            return left.SequenceEqual(right, comparer);
        }

        /// <summary>
        /// Order-sensitive hash of a sequence's contents, consistent with <see cref="SequencesEqual{T}"/>
        /// under default element equality. A null sequence hashes to 1; null elements contribute 0.
        /// </summary>
        private static int SequenceHash<T>(IEnumerable<T> items)
        {
            if (items == null)
            {
                return 1;
            }
            unchecked
            {
                int hash = 17;
                foreach (T item in items)
                {
                    hash = hash * 31 + (item == null ? 0 : item.GetHashCode());
                }
                return hash;
            }
        }

        /// <summary>
        /// Compares lists of doubles by contents (null-safe) rather than by reference.
        /// </summary>
        private sealed class DoubleListComparer : IEqualityComparer<List<double>>
        {
            internal static readonly DoubleListComparer Instance = new DoubleListComparer();

            public bool Equals(List<double> x, List<double> y)
            {
                return SequencesEqual(x, y);
            }

            public int GetHashCode(List<double> obj)
            {
                return SequenceHash(obj);
            }
        }
    }
}