-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathwc_mlkem.h
More file actions
370 lines (321 loc) · 11.9 KB
/
wc_mlkem.h
File metadata and controls
370 lines (321 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
/* wc_mlkem.h
*
* Copyright (C) 2006-2025 wolfSSL Inc.
*
* This file is part of wolfSSL.
*
* wolfSSL is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* wolfSSL is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335, USA
*/
/*!
\file wolfssl/wolfcrypt/wc_mlkem.h
*/
#ifndef WOLF_CRYPT_WC_MLKEM_H
#define WOLF_CRYPT_WC_MLKEM_H
#include <wolfssl/wolfcrypt/types.h>
#include <wolfssl/wolfcrypt/random.h>
#include <wolfssl/wolfcrypt/sha3.h>
#include <wolfssl/wolfcrypt/mlkem.h>
#ifdef WOLFSSL_HAVE_MLKEM
#ifdef WOLFSSL_KYBER_NO_MAKE_KEY
#define WOLFSSL_MLKEM_NO_MAKE_KEY
#endif
#ifdef WOLFSSL_KYBER_NO_ENCAPSULATE
#define WOLFSSL_MLKEM_NO_ENCAPSULATE
#endif
#ifdef WOLFSSL_KYBER_NO_DECAPSULATE
#define WOLFSSL_MLKEM_NO_DECAPSULATE
#endif
#ifdef noinline
#define MLKEM_NOINLINE noinline
#elif defined(_MSC_VER)
#define MLKEM_NOINLINE __declspec(noinline)
#elif defined(__GNUC__)
#define MLKEM_NOINLINE __attribute__((noinline))
#else
#define MLKEM_NOINLINE
#endif
enum {
/* Flags of Kyber keys. */
MLKEM_FLAG_PRIV_SET = 0x0001,
MLKEM_FLAG_PUB_SET = 0x0002,
MLKEM_FLAG_BOTH_SET = 0x0003,
MLKEM_FLAG_H_SET = 0x0004,
MLKEM_FLAG_A_SET = 0x0008,
/* 2 bits of random used to create noise value. */
MLKEM_CBD_ETA2 = 2,
/* 3 bits of random used to create noise value. */
MLKEM_CBD_ETA3 = 3,
/* Number of bits to compress to. */
MLKEM_COMP_4BITS = 4,
MLKEM_COMP_5BITS = 5,
MLKEM_COMP_10BITS = 10,
MLKEM_COMP_11BITS = 11,
};
/* SHAKE128 rate. */
#define XOF_BLOCK_SIZE 168
/* Modulus of co-efficients of polynomial. */
#define MLKEM_Q 3329
/* Kyber-512 parameters */
#ifdef WOLFSSL_WC_ML_KEM_512
/* Number of bits of random to create noise from. */
#define WC_ML_KEM_512_ETA1 MLKEM_CBD_ETA3
#endif /* WOLFSSL_WC_ML_KEM_512 */
/* Kyber-768 parameters */
#ifdef WOLFSSL_WC_ML_KEM_768
/* Number of bits of random to create noise from. */
#define WC_ML_KEM_768_ETA1 MLKEM_CBD_ETA2
#endif /* WOLFSSL_WC_ML_KEM_768 */
/* Kyber-1024 parameters */
#ifdef WOLFSSL_WC_ML_KEM_1024
/* Number of bits of random to create noise from. */
#define WC_ML_KEM_1024_ETA1 MLKEM_CBD_ETA2
#endif /* WOLFSSL_KYBER1024 */
/* The data type of the hash function. */
#define MLKEM_HASH_T wc_Sha3
/* The data type of the pseudo-random function. */
#define MLKEM_PRF_T wc_Shake
/* ML-KEM key. */
struct MlKemKey {
/* Type of key: WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024 */
int type;
/* Dynamic memory allocation hint. */
void* heap;
#if defined(WOLF_CRYPTO_CB)
/* Device Id. */
int devId;
#endif
/* Flags indicating what is stored in the key. */
int flags;
/* A pseudo-random function object. */
MLKEM_HASH_T hash;
/* A pseudo-random function object. */
MLKEM_PRF_T prf;
/* Private key as a vector. */
sword16 priv[WC_ML_KEM_MAX_K * MLKEM_N];
/* Public key as a vector. */
sword16 pub[WC_ML_KEM_MAX_K * MLKEM_N];
/* Public seed. */
byte pubSeed[WC_ML_KEM_SYM_SZ];
/* Public hash - hash of encoded public key. */
byte h[WC_ML_KEM_SYM_SZ];
/* Randomizer for decapsulation. */
byte z[WC_ML_KEM_SYM_SZ];
#ifdef WOLFSSL_MLKEM_CACHE_A
/* A matrix from key generation. */
sword16 a[WC_ML_KEM_MAX_K * WC_ML_KEM_MAX_K * MLKEM_N];
#endif
};
#ifdef __cplusplus
extern "C" {
#endif
WOLFSSL_LOCAL
void mlkem_init(void);
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
WOLFSSL_LOCAL
void mlkem_keygen(sword16* priv, sword16* pub, sword16* e, const sword16* a,
int kp);
#else
WOLFSSL_LOCAL
int mlkem_keygen_seeds(sword16* priv, sword16* pub, MLKEM_PRF_T* prf,
sword16* e, int kp, byte* seed, byte* noiseSeed);
#endif
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
WOLFSSL_LOCAL
void mlkem_encapsulate(const sword16* pub, sword16* bp, sword16* v,
const sword16* at, sword16* sp, const sword16* ep, const sword16* epp,
const sword16* m, int kp);
#else
WOLFSSL_LOCAL
int mlkem_encapsulate_seeds(const sword16* pub, MLKEM_PRF_T* prf, sword16* bp,
sword16* tp, sword16* sp, int kp, const byte* msg, byte* seed,
byte* coins);
#endif
WOLFSSL_LOCAL
void mlkem_decapsulate(const sword16* priv, sword16* mp, sword16* bp,
const sword16* v, int kp);
WOLFSSL_LOCAL
int mlkem_gen_matrix(MLKEM_PRF_T* prf, sword16* a, int kp, byte* seed,
int transposed);
WOLFSSL_LOCAL
int mlkem_get_noise(MLKEM_PRF_T* prf, int kp, sword16* vec1, sword16* vec2,
sword16* poly, byte* seed);
#if defined(USE_INTEL_SPEEDUP) || \
(defined(WOLFSSL_ARMASM) && defined(__aarch64__))
WOLFSSL_LOCAL
int mlkem_kdf(byte* seed, int seedLen, byte* out, int outLen);
#endif
WOLFSSL_LOCAL
void mlkem_hash_init(MLKEM_HASH_T* hash);
WOLFSSL_LOCAL
int mlkem_hash_new(MLKEM_HASH_T* hash, void* heap, int devId);
WOLFSSL_LOCAL
void mlkem_hash_free(MLKEM_HASH_T* hash);
WOLFSSL_LOCAL
int mlkem_hash256(wc_Sha3* hash, const byte* data, word32 dataLen, byte* out);
WOLFSSL_LOCAL
int mlkem_hash512(wc_Sha3* hash, const byte* data1, word32 data1Len,
const byte* data2, word32 data2Len, byte* out);
WOLFSSL_LOCAL
int mlkem_derive_secret(MLKEM_PRF_T* prf, const byte* z, const byte* ct,
word32 ctSz, byte* ss);
WOLFSSL_LOCAL
void mlkem_prf_init(MLKEM_PRF_T* prf);
WOLFSSL_LOCAL
int mlkem_prf_new(MLKEM_PRF_T* prf, void* heap, int devId);
WOLFSSL_LOCAL
void mlkem_prf_free(MLKEM_PRF_T* prf);
WOLFSSL_LOCAL
int mlkem_cmp(const byte* a, const byte* b, int sz);
WOLFSSL_LOCAL
void mlkem_vec_compress_10(byte* r, sword16* v, unsigned int kp);
WOLFSSL_LOCAL
void mlkem_vec_compress_11(byte* r, sword16* v);
WOLFSSL_LOCAL
void mlkem_vec_decompress_10(sword16* v, const unsigned char* b,
unsigned int kp);
WOLFSSL_LOCAL
void mlkem_vec_decompress_11(sword16* v, const unsigned char* b);
WOLFSSL_LOCAL
void mlkem_compress_4(byte* b, sword16* p);
WOLFSSL_LOCAL
void mlkem_compress_5(byte* b, sword16* p);
WOLFSSL_LOCAL
void mlkem_decompress_4(sword16* p, const unsigned char* b);
WOLFSSL_LOCAL
void mlkem_decompress_5(sword16* p, const unsigned char* b);
WOLFSSL_LOCAL
void mlkem_from_msg(sword16* p, const byte* msg);
WOLFSSL_LOCAL
void mlkem_to_msg(byte* msg, sword16* p);
WOLFSSL_LOCAL
void mlkem_from_bytes(sword16* p, const byte* b, int k);
WOLFSSL_LOCAL
void mlkem_to_bytes(byte* b, sword16* p, int k);
#ifdef USE_INTEL_SPEEDUP
WOLFSSL_LOCAL
void mlkem_keygen_avx2(sword16* priv, sword16* pub, sword16* e,
const sword16* a, int kp);
WOLFSSL_LOCAL
void mlkem_encapsulate_avx2(const sword16* pub, sword16* bp, sword16* v,
const sword16* at, sword16* sp, const sword16* ep, const sword16* epp,
const sword16* m, int kp);
WOLFSSL_LOCAL
void mlkem_decapsulate_avx2(const sword16* priv, sword16* mp, sword16* bp,
const sword16* v, int kp);
WOLFSSL_LOCAL
unsigned int mlkem_rej_uniform_n_avx2(sword16* p, unsigned int len,
const byte* r, unsigned int rLen);
WOLFSSL_LOCAL
unsigned int mlkem_rej_uniform_avx2(sword16* p, unsigned int len, const byte* r,
unsigned int rLen);
WOLFSSL_LOCAL
void mlkem_redistribute_21_rand_avx2(const word64* s, byte* r0, byte* r1,
byte* r2, byte* r3);
void mlkem_redistribute_17_rand_avx2(const word64* s, byte* r0, byte* r1,
byte* r2, byte* r3);
void mlkem_redistribute_16_rand_avx2(const word64* s, byte* r0, byte* r1,
byte* r2, byte* r3);
void mlkem_redistribute_8_rand_avx2(const word64* s, byte* r0, byte* r1,
byte* r2, byte* r3);
WOLFSSL_LOCAL
void mlkem_cbd_eta2_avx2(sword16* p, const byte* r);
WOLFSSL_LOCAL
void mlkem_cbd_eta3_avx2(sword16* p, const byte* r);
WOLFSSL_LOCAL
void mlkem_from_msg_avx2(sword16* p, const byte* msg);
WOLFSSL_LOCAL
void mlkem_to_msg_avx2(byte* msg, sword16* p);
WOLFSSL_LOCAL
void mlkem_from_bytes_avx2(sword16* p, const byte* b);
WOLFSSL_LOCAL
void mlkem_to_bytes_avx2(byte* b, sword16* p);
WOLFSSL_LOCAL
void mlkem_compress_10_avx2(byte* r, const sword16* p, int n);
WOLFSSL_LOCAL
void mlkem_decompress_10_avx2(sword16* p, const byte* r, int n);
WOLFSSL_LOCAL
void mlkem_compress_11_avx2(byte* r, const sword16* p, int n);
WOLFSSL_LOCAL
void mlkem_decompress_11_avx2(sword16* p, const byte* r, int n);
WOLFSSL_LOCAL
void mlkem_compress_4_avx2(byte* r, const sword16* p);
WOLFSSL_LOCAL
void mlkem_decompress_4_avx2(sword16* p, const byte* r);
WOLFSSL_LOCAL
void mlkem_compress_5_avx2(byte* r, const sword16* p);
WOLFSSL_LOCAL
void mlkem_decompress_5_avx2(sword16* p, const byte* r);
WOLFSSL_LOCAL
int mlkem_cmp_avx2(const byte* a, const byte* b, int sz);
#elif defined(__aarch64__) && defined(WOLFSSL_ARMASM)
WOLFSSL_LOCAL void mlkem_ntt(sword16* r);
WOLFSSL_LOCAL void mlkem_invntt(sword16* r);
WOLFSSL_LOCAL void mlkem_ntt_sqrdmlsh(sword16* r);
WOLFSSL_LOCAL void mlkem_invntt_sqrdmlsh(sword16* r);
WOLFSSL_LOCAL void mlkem_basemul_mont(sword16* r, const sword16* a,
const sword16* b);
WOLFSSL_LOCAL void mlkem_basemul_mont_add(sword16* r, const sword16* a,
const sword16* b);
WOLFSSL_LOCAL void mlkem_add_reduce(sword16* r, const sword16* a);
WOLFSSL_LOCAL void mlkem_add3_reduce(sword16* r, const sword16* a,
const sword16* b);
WOLFSSL_LOCAL void mlkem_rsub_reduce(sword16* r, const sword16* a);
WOLFSSL_LOCAL void mlkem_to_mont(sword16* p);
WOLFSSL_LOCAL void mlkem_to_mont_sqrdmlsh(sword16* p);
WOLFSSL_LOCAL void mlkem_sha3_blocksx3_neon(word64* state);
WOLFSSL_LOCAL void mlkem_shake128_blocksx3_seed_neon(word64* state, byte* seed);
WOLFSSL_LOCAL void mlkem_shake256_blocksx3_seed_neon(word64* state, byte* seed);
WOLFSSL_LOCAL unsigned int mlkem_rej_uniform_neon(sword16* p, unsigned int len,
const byte* r, unsigned int rLen);
WOLFSSL_LOCAL int mlkem_cmp_neon(const byte* a, const byte* b, int sz);
WOLFSSL_LOCAL void mlkem_csubq_neon(sword16* p);
WOLFSSL_LOCAL void mlkem_from_msg_neon(sword16* p, const byte* msg);
WOLFSSL_LOCAL void mlkem_to_msg_neon(byte* msg, sword16* p);
#elif defined(WOLFSSL_ARMASM_THUMB2) && defined(WOLFSSL_ARMASM)
#define mlkem_ntt mlkem_thumb2_ntt
#define mlkem_invntt mlkem_thumb2_invntt
#define mlkem_basemul_mont mlkem_thumb2_basemul_mont
#define mlkem_basemul_mont_add mlkem_thumb2_basemul_mont_add
#define mlkem_rej_uniform_c mlkem_thumb2_rej_uniform
WOLFSSL_LOCAL void mlkem_thumb2_ntt(sword16* r);
WOLFSSL_LOCAL void mlkem_thumb2_invntt(sword16* r);
WOLFSSL_LOCAL void mlkem_thumb2_basemul_mont(sword16* r, const sword16* a,
const sword16* b);
WOLFSSL_LOCAL void mlkem_thumb2_basemul_mont_add(sword16* r, const sword16* a,
const sword16* b);
WOLFSSL_LOCAL void mlkem_thumb2_csubq(sword16* p);
WOLFSSL_LOCAL unsigned int mlkem_thumb2_rej_uniform(sword16* p,
unsigned int len, const byte* r, unsigned int rLen);
#elif defined(WOLFSSL_ARMASM)
#define mlkem_ntt mlkem_arm32_ntt
#define mlkem_invntt mlkem_arm32_invntt
#define mlkem_basemul_mont mlkem_arm32_basemul_mont
#define mlkem_basemul_mont_add mlkem_arm32_basemul_mont_add
#define mlkem_rej_uniform_c mlkem_arm32_rej_uniform
WOLFSSL_LOCAL void mlkem_arm32_ntt(sword16* r);
WOLFSSL_LOCAL void mlkem_arm32_invntt(sword16* r);
WOLFSSL_LOCAL void mlkem_arm32_basemul_mont(sword16* r, const sword16* a,
const sword16* b);
WOLFSSL_LOCAL void mlkem_arm32_basemul_mont_add(sword16* r, const sword16* a,
const sword16* b);
WOLFSSL_LOCAL void mlkem_arm32_csubq(sword16* p);
WOLFSSL_LOCAL unsigned int mlkem_arm32_rej_uniform(sword16* p, unsigned int len,
const byte* r, unsigned int rLen);
#endif
#ifdef __cplusplus
} /* extern "C" */
#endif
#endif /* WOLFSSL_HAVE_MLKEM */
#endif /* WOLF_CRYPT_WC_MLKEM_H */