// gpt2_webgpu.cpp
// Forked from AnswerDotAI/gpu.cpp.
#include "gpu.hpp"
#include "ops.hpp"
/*
This file trains the GPT-2 model.
This version is the clean, minimal reference implementation. As such:
- it runs on CPU.
- it does not make the code too complex; it is readable.
- it does not use any processor-specific instructions, intrinsics and such.
- it _does_ use a few OpenMP pragmas because this is a large speedup at very low cost
There will be other versions of this code that specialize it and make it fast.
*/
#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>
#include <stdint.h>
#include <assert.h>
#include <math.h>
#include <time.h>
#include <string.h>
#include <unistd.h>
#include <memory>
#ifdef OMP
#include <omp.h>
#endif
// our own utilities
// defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck
#include "llmc/utils.h"
// defines: tokenizer_init, tokenizer_decode, tokenizer_free
#include "llmc/tokenizer.h"
// defines: dataloader_init, dataloader_reset, dataloader_next_batch, dataloader_free
#include "llmc/dataloader.h"
using namespace gpu;
// ----------------------------------------------------------------------------
// GPT-2 model definition
// Hyperparameters describing the model architecture.
// gpt2_build_from_checkpoint fills every field from the integer header of the
// checkpoint file; the sizing helpers below read them to size all tensors.
typedef struct {
int max_seq_len; // max sequence length, e.g. 1024
int vocab_size; // vocab size, e.g. 50257
int padded_vocab_size; // padded to e.g. %128==0, 50304
int num_layers; // number of layers, e.g. 12
int num_heads; // number of heads in attention, e.g. 12
int channels; // number of channels, e.g. 768
} GPT2Config;
// the parameters of the model
// NUM_PARAMETER_TENSORS must equal the number of float* fields in
// ParameterTensors: it sizes the param_sizes[] array and bounds the loops that
// walk the pointer table in malloc_and_point_parameters.
#define NUM_PARAMETER_TENSORS 16
// Each field points into one contiguous float buffer (see
// malloc_and_point_parameters). Shapes are given in the trailing comments with
// V = vocab, C = channels, L = layers, maxT = max sequence length. Note that
// fill_in_parameter_sizes sizes wte with the *padded* vocab (Vp * C).
typedef struct {
float* wte; // (V, C)
float* wpe; // (maxT, C)
float* ln1w; // (L, C)
float* ln1b; // (L, C)
float* qkvw; // (L, 3*C, C)
float* qkvb; // (L, 3*C)
float* attprojw; // (L, C, C)
float* attprojb; // (L, C)
float* ln2w; // (L, C)
float* ln2b; // (L, C)
float* fcw; // (L, 4*C, C)
float* fcb; // (L, 4*C)
float* fcprojw; // (L, C, 4*C)
float* fcprojb; // (L, C)
float* lnfw; // (C)
float* lnfb; // (C)
} ParameterTensors;
void fill_in_parameter_sizes(size_t* param_sizes, GPT2Config config) {
size_t Vp = config.padded_vocab_size;
size_t C = config.channels;
size_t maxT = config.max_seq_len;
size_t L = config.num_layers;
param_sizes[0] = Vp * C; // wte
param_sizes[1] = maxT * C; // wpe
param_sizes[2] = L * C; // ln1w
param_sizes[3] = L * C; // ln1b
param_sizes[4] = L * (3 * C) * C; // qkvw
param_sizes[5] = L * (3 * C); // qkvb
param_sizes[6] = L * C * C; // attprojw
param_sizes[7] = L * C; // attprojb
param_sizes[8] = L * C; // ln2w
param_sizes[9] = L * C; // ln2b
param_sizes[10] = L * (4 * C) * C; // fcw
param_sizes[11] = L * (4 * C); // fcb
param_sizes[12] = L * C * (4 * C); // fcprojw
param_sizes[13] = L * C; // fcprojb
param_sizes[14] = C; // lnfw
param_sizes[15] = C; // lnfb
}
// Allocate one contiguous float buffer large enough for all parameter tensors
// and point each field of `params` at its slice of that buffer.
//
//   params      - struct whose 16 tensor pointers are overwritten
//   param_sizes - element counts, NUM_PARAMETER_TENSORS entries, in the same
//                 order as the fields of ParameterTensors
//
// Returns the base pointer of the allocation. The caller owns it and releases
// it with free(); the pointers stored in `params` alias into it and must not
// be freed individually.
float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_sizes) {
    size_t num_parameters = 0;
    for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
        num_parameters += param_sizes[i];
    }
    // malloc all parameters all at once
    float* params_memory = (float*)mallocCheck(num_parameters * sizeof(float));
    // table of the struct fields to fill, in the same order as param_sizes
    float** ptrs[] = {
        &params->wte, &params->wpe, &params->ln1w, &params->ln1b, &params->qkvw, &params->qkvb,
        &params->attprojw, &params->attprojb, &params->ln2w, &params->ln2b, &params->fcw, &params->fcb,
        &params->fcprojw, &params->fcprojb, &params->lnfw, &params->lnfb
    };
    // hand out consecutive slices of the buffer
    float* params_memory_iterator = params_memory;
    for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
        *(ptrs[i]) = params_memory_iterator;
        params_memory_iterator += param_sizes[i];
    }
    return params_memory;
}
// NUM_ACTIVATION_TENSORS must equal the number of float* fields in
// ActivationTensors: it sizes act_sizes[] and bounds the loops that walk the
// pointer table in malloc_and_point_activations.
#define NUM_ACTIVATION_TENSORS 23
// Intermediate results of the forward pass. Each field points into one
// contiguous float buffer (see malloc_and_point_activations). Shapes are given
// in the trailing comments with B = batch, T = sequence length, C = channels,
// L = layers, NH = heads, V = vocab. Note that fill_in_activation_sizes sizes
// logits and probs with the *padded* vocab (B * T * Vp).
typedef struct {
float* encoded; // (B, T, C)
float* ln1; // (L, B, T, C)
float* ln1_mean; // (L, B, T)
float* ln1_rstd; // (L, B, T)
float* qkv; // (L, B, T, 3*C)
float* atty; // (L, B, T, C)
float* preatt; // (L, B, NH, T, T)
float* att; // (L, B, NH, T, T)
float* attproj; // (L, B, T, C)
float* residual2; // (L, B, T, C)
float* ln2; // (L, B, T, C)
float* ln2_mean; // (L, B, T)
float* ln2_rstd; // (L, B, T)
float* fch; // (L, B, T, 4*C)
float* fch_gelu; // (L, B, T, 4*C)
float* fcproj; // (L, B, T, C)
float* residual3; // (L, B, T, C)
float* lnf; // (B, T, C)
float* lnf_mean; // (B, T)
float* lnf_rstd; // (B, T)
float* logits; // (B, T, V)
float* probs; // (B, T, V)
float* losses; // (B, T)
} ActivationTensors;
// Writes the element count (number of floats) of every activation tensor into
// act_sizes[0 .. NUM_ACTIVATION_TENSORS-1], in the same order as the fields of
// ActivationTensors.
//
//   act_sizes - output array with NUM_ACTIVATION_TENSORS entries
//   config    - model hyperparameters (channels, heads, layers, padded vocab)
//   B, T      - batch size and sequence length of the forward pass
//
// B and T are widened to size_t before any multiplication so that products
// such as B * T * Vp are never evaluated in int (which could overflow).
void fill_in_activation_sizes(size_t* act_sizes, GPT2Config config, int B, int T) {
    size_t C = config.channels;
    size_t NH = config.num_heads;
    size_t L = config.num_layers;
    size_t Vp = config.padded_vocab_size;
    size_t Bs = (size_t)B;
    size_t Ts = (size_t)T;
    act_sizes[0] = Bs * Ts * C; // encoded
    act_sizes[1] = L * Bs * Ts * C; // ln1
    act_sizes[2] = L * Bs * Ts; // ln1_mean
    act_sizes[3] = L * Bs * Ts; // ln1_rstd
    act_sizes[4] = L * Bs * Ts * 3 * C; // qkv
    act_sizes[5] = L * Bs * Ts * C; // atty
    act_sizes[6] = L * Bs * NH * Ts * Ts; // preatt
    act_sizes[7] = L * Bs * NH * Ts * Ts; // att
    act_sizes[8] = L * Bs * Ts * C; // attproj
    act_sizes[9] = L * Bs * Ts * C; // residual2
    act_sizes[10] = L * Bs * Ts * C; // ln2
    act_sizes[11] = L * Bs * Ts; // ln2_mean
    act_sizes[12] = L * Bs * Ts; // ln2_rstd
    act_sizes[13] = L * Bs * Ts * 4 * C; // fch
    act_sizes[14] = L * Bs * Ts * 4 * C; // fch_gelu
    act_sizes[15] = L * Bs * Ts * C; // fcproj
    act_sizes[16] = L * Bs * Ts * C; // residual3
    act_sizes[17] = Bs * Ts * C; // lnf
    act_sizes[18] = Bs * Ts; // lnf_mean
    act_sizes[19] = Bs * Ts; // lnf_rstd
    act_sizes[20] = Bs * Ts * Vp; // logits
    act_sizes[21] = Bs * Ts * Vp; // probs
    act_sizes[22] = Bs * Ts; // losses
}
// Allocates a single float buffer holding every activation tensor back to
// back and stores, in each field of `acts`, the address of that tensor's
// slice. `act_sizes` holds NUM_ACTIVATION_TENSORS element counts in field
// order. Returns the base of the buffer, which the caller must free().
float* malloc_and_point_activations(ActivationTensors* acts, size_t* act_sizes) {
    size_t total_floats = 0;
    for (size_t k = 0; k < NUM_ACTIVATION_TENSORS; ++k) {
        total_floats += act_sizes[k];
    }

    float* const base = (float*)mallocCheck(total_floats * sizeof(float));

    // destination fields, listed in the same order as act_sizes
    float** const slots[NUM_ACTIVATION_TENSORS] = {
        &acts->encoded, &acts->ln1, &acts->ln1_mean, &acts->ln1_rstd,
        &acts->qkv, &acts->atty, &acts->preatt, &acts->att,
        &acts->attproj, &acts->residual2, &acts->ln2, &acts->ln2_mean,
        &acts->ln2_rstd, &acts->fch, &acts->fch_gelu, &acts->fcproj,
        &acts->residual3, &acts->lnf, &acts->lnf_mean, &acts->lnf_rstd,
        &acts->logits, &acts->probs, &acts->losses
    };

    float* cursor = base;
    for (size_t k = 0; k < NUM_ACTIVATION_TENSORS; ++k) {
        *slots[k] = cursor;
        cursor += act_sizes[k];
    }
    return base;
}
// GPU-side counterpart of ParameterTensors: one Tensor per parameter tensor.
// gpt2_build_from_checkpoint fills data[] via gpu_alloc using param_sizes, so
// data[i] has param_sizes[i] elements.
struct GPUParameters {
Tensor data[NUM_PARAMETER_TENSORS];
};
// GPU-side counterpart of ActivationTensors: one Tensor slot per activation
// tensor. gpt2_forward fills data[] lazily via gpu_alloc using act_sizes.
struct GPUActivations {
Tensor data[NUM_ACTIVATION_TENSORS];
};
// Creates `n` GPU tensors through `ctx` and stores them in tensors[0..n-1].
// Tensor k is created with the single-dimension shape {sizes[k]} and element
// type kf32 (presumably 32-bit float; defined by gpu.hpp).
void gpu_alloc(Context& ctx, Tensor* tensors, size_t* sizes, size_t n) {
    size_t idx = 0;
    while (idx < n) {
        tensors[idx] = createTensor(ctx, Shape{sizes[idx]}, kf32);
        ++idx;
    }
}
// All state of one model instance: configuration, CPU and GPU parameter
// storage, gradients, optimizer moments, activations and cached batch data.
// Every *_memory pointer, plus inputs/targets, is released by gpt2_free.
// gpt2_build_from_checkpoint sets the lazily-allocated ones to NULL, and
// gpt2_forward / gpt2_backward / gpt2_update allocate them on first use.
typedef struct {
GPT2Config config;
// the weights (parameters) of the model, and their sizes
ParameterTensors params;
GPUParameters params_; // TODO(avh): eventually this replaces params
size_t param_sizes[NUM_PARAMETER_TENSORS];
float* params_memory;
size_t num_parameters;
// gradients of the weights
ParameterTensors grads;
float* grads_memory;
// buffers for the AdamW optimizer
float* m_memory;
float* v_memory;
// the activations of the model, and their sizes
ActivationTensors acts;
GPUActivations acts_; // TODO(avh): eventually this replaces acts
size_t act_sizes[NUM_ACTIVATION_TENSORS];
float* acts_memory;
size_t num_activations;
// gradients of the activations
ActivationTensors grads_acts;
float* grads_acts_memory;
// other run state configuration
int batch_size; // the batch size (B) of current forward pass
int seq_len; // the sequence length (T) of current forward pass
int* inputs; // the input tokens for the current forward pass
int* targets; // the target tokens for the current forward pass
float mean_loss; // after a forward pass with targets, will be populated with the mean loss
} GPT2;
// Initializes `model` from the checkpoint file at `checkpoint_path`.
//
// The file starts with a 256-int header: [0] magic (20240326), [1] version (3),
// [2..7] hyperparameters. It is followed by all parameters as raw floats in
// ParameterTensors order. On a bad magic or version the function prints a
// message and exits. After loading, all lazily-allocated buffers are reset to
// NULL and GPU tensors for the parameters are created through `ctx`.
void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoint_path) {
    FILE *ckpt = fopenCheck(checkpoint_path, "rb");

    // header: magic, version, then the hyperparameters
    int header[256];
    freadCheck(header, sizeof(int), 256, ckpt);
    if (header[0] != 20240326) {
        printf("Bad magic model file\n");
        exit(1);
    }
    if (header[1] != 3) {
        printf("Bad version in model file\n");
        printf("---> HINT: try to re-run `python train_gpt2.py`\n");
        exit(1);
    }

    // copy the hyperparameters into the config
    GPT2Config *cfg = &model->config;
    cfg->max_seq_len = header[2];
    cfg->vocab_size = header[3];
#ifdef __EMSCRIPTEN__
    cfg->num_layers = 12; // TODO(avh): Debugging only hack - revert this
#else
    cfg->num_layers = header[4];
#endif
    cfg->num_heads = header[5];
    cfg->channels = header[6];
    cfg->padded_vocab_size = header[7];

    // report them (widened to size_t for printing)
    printf("[GPT-2]\n");
    printf("max_seq_len: %zu\n", (size_t)cfg->max_seq_len);
    printf("vocab_size: %zu\n", (size_t)cfg->vocab_size);
    printf("padded_vocab_size: %zu\n", (size_t)cfg->padded_vocab_size);
    printf("num_layers: %zu\n", (size_t)cfg->num_layers);
    printf("num_heads: %zu\n", (size_t)cfg->num_heads);
    printf("channels: %zu\n", (size_t)cfg->channels);

    // size every parameter tensor and total them up
    fill_in_parameter_sizes(model->param_sizes, model->config);
    size_t total_params = 0;
    for (size_t k = 0; k < NUM_PARAMETER_TENSORS; ++k) {
        total_params += model->param_sizes[k];
    }
    printf("num_parameters: %zu\n", total_params);
    model->num_parameters = total_params;

    // allocate the parameter buffer and read the weights straight into it
    model->params_memory = malloc_and_point_parameters(&model->params, model->param_sizes);
    freadCheck(model->params_memory, sizeof(float), total_params, ckpt);
    fcloseCheck(ckpt);

    // everything else is allocated lazily later on
    model->acts_memory = NULL;
    model->grads_memory = NULL;
    model->grads_acts_memory = NULL;
    model->m_memory = NULL;
    model->v_memory = NULL;
    model->inputs = NULL;
    model->targets = NULL;
    model->batch_size = 0;
    model->seq_len = 0;
    model->mean_loss = -1.0f; // -1.0f will designate no loss

    // TODO(avh): this is just a resource test for now, eventually deprecate CPU allocations
    gpu_alloc(ctx, model->params_.data, model->param_sizes, NUM_PARAMETER_TENSORS);
}
// Runs the forward pass of the model on one batch.
//
//   ctx     - GPU context handed to every op and to gpu_alloc
//   model   - model previously set up by gpt2_build_from_checkpoint
//   inputs  - B*T token ids, each in [0, vocab_size)
//   targets - B*T target token ids, or NULL for inference (no loss computed)
//   B, T    - batch size and sequence length
//
// On the first call the activation buffers (CPU and GPU) and the input/target
// caches are allocated for this (B, T); later calls must use the same B and T,
// otherwise the function prints both shapes and exits.
// On return model->acts holds all activations (probs included) and
// model->mean_loss is the mean of the per-position losses, or -1.0f when
// targets == NULL.
void gpt2_forward(Context& ctx, GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {
    // targets are optional and could be NULL

    // ensure the model was initialized or error out
    if (model->params_memory == NULL) {
        printf("Error: model was not initialized properly.\n");
        exit(1);
    }

    // convenience parameters (size_t to help prevent int overflow)
    size_t V = model->config.vocab_size;
    size_t Vp = model->config.padded_vocab_size;
    size_t L = model->config.num_layers;
    size_t NH = model->config.num_heads;
    size_t C = model->config.channels;

    // validate inputs, all indices must be in the range [0, V)
    for (size_t i = 0; i < B * T; i++) {
        assert(0 <= inputs[i] && (size_t)inputs[i] < V);
        if (targets != NULL) {
            assert(0 <= targets[i] && (size_t)targets[i] < V);
        }
    }

    // allocate space for all the activations if needed (done here, lazily)
    if (model->acts_memory == NULL) {
        // record the current B,T as well
        model->batch_size = (int)B;
        model->seq_len = (int)T;
        // and now allocate the space
        fill_in_activation_sizes(model->act_sizes, model->config, (int)B, (int)T);
        // TODO(avh): this is just a resource test for now, eventually deprecate CPU allocations
        // One GPU tensor per *activation* tensor: act_sizes and acts_.data both
        // have NUM_ACTIVATION_TENSORS entries.
        gpu_alloc(ctx, model->acts_.data, model->act_sizes, NUM_ACTIVATION_TENSORS);
        size_t num_activations = 0;
        for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
            num_activations += model->act_sizes[i];
        }
        printf("num_activations: %zu\n", num_activations);
        model->num_activations = num_activations;
        printf("Allocating %.2f MB for activations\n", num_activations * sizeof(float) / (1024.0f * 1024.0f));
        model->acts_memory = malloc_and_point_activations(&model->acts, model->act_sizes);
        // also create memory for caching inputs and targets
        model->inputs = (int*)mallocCheck(B * T * sizeof(int));
        model->targets = (int*)mallocCheck(B * T * sizeof(int)); // might be unused if we never have targets but it's small
    } else {
        // validate B,T is consistent with how we've allocated the memory before
        // in principle we could get more clever here in the future, for now this is safest
        if (B != (size_t)model->batch_size || T != (size_t)model->seq_len) {
            printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, (int)B, (int)T);
            exit(EXIT_FAILURE);
        }
    }

    printf("Cache inputs/targets\n");
    // cache the inputs/targets (gpt2_backward reads them back from the model)
    memcpy(model->inputs, inputs, B * T * sizeof(int));
    if (targets != NULL) {
        memcpy(model->targets, targets, B * T * sizeof(int));
    }

    printf("Forward pass\n");
    // forward pass
    ParameterTensors params = model->params; // for brevity
    ActivationTensors acts = model->acts;
    float* residual;
    printf("Encoding\n");
    printf("inputs[0] = %d\n", inputs[0]);
    encoder_forward(ctx, acts.encoded, inputs, params.wte, params.wpe, B, T, C); // encoding goes into residual[0]
    for (int l = 0; (size_t)l < L; l++) {
        printf("Forward Pass Layer %d\n", l);

        // input of this layer: the encoding for layer 0, else the previous layer's output
        residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;

        // get the pointers of the weights for this layer
        float* l_ln1w = params.ln1w + l * C;
        float* l_ln1b = params.ln1b + l * C;
        float* l_qkvw = params.qkvw + l * 3*C * C;
        float* l_qkvb = params.qkvb + l * 3*C;
        float* l_attprojw = params.attprojw + l * C * C;
        float* l_attprojb = params.attprojb + l * C;
        float* l_ln2w = params.ln2w + l * C;
        float* l_ln2b = params.ln2b + l * C;
        float* l_fcw = params.fcw + l * 4*C * C;
        float* l_fcb = params.fcb + l * 4*C;
        float* l_fcprojw = params.fcprojw + l * C * 4*C;
        float* l_fcprojb = params.fcprojb + l * C;

        // get the pointers of the activations for this layer
        float* l_ln1 = acts.ln1 + l * B * T * C;
        float* l_ln1_mean = acts.ln1_mean + l * B * T;
        float* l_ln1_rstd = acts.ln1_rstd + l * B * T;
        float* l_qkv = acts.qkv + l * B * T * 3*C;
        float* l_atty = acts.atty + l * B * T * C;
        float* l_preatt = acts.preatt + l * B * NH * T * T;
        float* l_att = acts.att + l * B * NH * T * T;
        float* l_attproj = acts.attproj + l * B * T * C;
        float* l_residual2 = acts.residual2 + l * B * T * C;
        float* l_ln2 = acts.ln2 + l * B * T * C;
        float* l_ln2_mean = acts.ln2_mean + l * B * T;
        float* l_ln2_rstd = acts.ln2_rstd + l * B * T;
        float* l_fch = acts.fch + l * B * T * 4*C;
        float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C;
        float* l_fcproj = acts.fcproj + l * B * T * C;
        float* l_residual3 = acts.residual3 + l * B * T * C;

        // now do the forward pass
        printf(" [Forward] : LayerNorm1\n");
        layernorm_forward(ctx, l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C);
        printf(" [Forward] : QKV Projection\n");
        matmul_forward(ctx, l_qkv, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C);
        printf(" [Forward] : Attention\n");
        attention_forward(ctx, l_atty, l_preatt, l_att, l_qkv, B, T, C, NH);
        printf(" [Forward] : Attention Projection\n");
        matmul_forward(ctx, l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C);
        printf(" [Forward] : Residual1\n");
        residual_forward(ctx, l_residual2, residual, l_attproj, B*T*C);
        printf(" [Forward] : LayerNorm2\n");
        layernorm_forward(ctx, l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);
        printf(" [Forward] : FF Up\n");
        matmul_forward(ctx, l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C);
        printf(" [Forward] : GELU\n");
        gelu_forward(ctx, l_fch_gelu, l_fch, B*T*4*C);
        printf(" [Forward] : FF Down\n");
        matmul_forward(ctx, l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C);
        printf(" [Forward] : Residual2\n");
        residual_forward(ctx, l_residual3, l_residual2, l_fcproj, B*T*C);
    }
    residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
    layernorm_forward(ctx, acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C);
    // the output projection reuses the token embedding matrix (wte) and has no bias
    matmul_forward(ctx, acts.logits, acts.lnf, params.wte, NULL, B, T, C, Vp);
    softmax_forward(ctx, acts.probs, acts.logits, B, T, V, Vp);

    printf("Crossentropy\n");
    // also forward the cross-entropy loss function if we have the targets
    if (targets != NULL) {
        crossentropy_forward(ctx, model->acts.losses, model->acts.probs, targets, B, T, Vp);
        // for convenience also evaluate the mean loss
        float mean_loss = 0.0f;
        for (size_t i = 0; i < B * T; i++) { mean_loss += model->acts.losses[i]; }
        mean_loss /= B*T;
        model->mean_loss = mean_loss;
    } else {
        // if we don't have targets, we don't have a loss
        model->mean_loss = -1.0f;
    }
    printf("Forward pass done\n");
}
// Resets the accumulated gradients to zero. Either gradient buffer may still
// be unallocated (NULL), in which case it is simply skipped.
void gpt2_zero_grad(GPT2 *model) {
    float* const weight_grads = model->grads_memory;
    float* const activation_grads = model->grads_acts_memory;

    if (weight_grads != NULL) {
        memset(weight_grads, 0, model->num_parameters * sizeof(float));
    }
    if (activation_grads != NULL) {
        memset(activation_grads, 0, model->num_activations * sizeof(float));
    }
}
// Runs the backward pass for the batch most recently passed to gpt2_forward.
//
// Requires that the last forward pass was given targets (mean_loss != -1.0f);
// otherwise prints an error and exits. On the first call the weight-gradient
// and activation-gradient buffers are allocated and zeroed. B and T are taken
// from the model (as recorded by gpt2_forward), and the cached model->inputs /
// model->targets are used instead of caller-supplied token arrays.
// Results are written into model->grads (weights) and model->grads_acts.
// NOTE(review): whether the *_backward ops accumulate into or overwrite their
// outputs is defined in ops.hpp, not here; main() calls gpt2_zero_grad before
// every backward pass.
void gpt2_backward(Context& ctx, GPT2 *model) {
printf("Backward pass\n");
// double check we forwarded previously, with targets
if (model->mean_loss == -1.0f) {
printf("Error: must forward with targets before backward\n");
exit(1);
}
// lazily allocate the memory for gradients of the weights and activations, if needed
if (model->grads_memory == NULL) {
printf("Allocating %.2f MB for gradients\n", model->num_parameters * sizeof(float) / (1024.0f * 1024.0f));
model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_sizes);
model->grads_acts_memory = malloc_and_point_activations(&model->grads_acts, model->act_sizes);
gpt2_zero_grad(model);
}
// convenience shortcuts (and size_t to help prevent int overflow)
size_t B = model->batch_size;
size_t T = model->seq_len;
size_t V = model->config.vocab_size;
size_t Vp = model->config.padded_vocab_size;
size_t L = model->config.num_layers;
size_t NH = model->config.num_heads;
size_t C = model->config.channels;
// backward pass: go in the reverse order of the forward pass, and call backward() functions
ParameterTensors params = model->params; // for brevity
ParameterTensors grads = model->grads;
ActivationTensors acts = model->acts;
ActivationTensors grads_acts = model->grads_acts;
// we kick off the chain rule by filling in dlosses with 1.0f/(B*T)
// technically this is a small, inline backward() pass of calculating
// total, final loss as the mean over all losses over all (B,T) positions in the batch
float dloss_mean = 1.0f / (B*T);
for (int i = 0; i < B*T; i++) { grads_acts.losses[i] = dloss_mean; }
crossentropy_softmax_backward(ctx, grads_acts.logits, grads_acts.losses, acts.probs, model->targets, B, T, V, Vp);
// the output projection shares wte with the encoder, so its weight gradient goes to grads.wte (no bias)
matmul_backward(ctx, grads_acts.lnf, grads.wte, NULL, grads_acts.logits, acts.lnf, params.wte, B, T, C, Vp);
float* residual = acts.residual3 + (L-1) * B * T * C; // last layer's residual
float* dresidual = grads_acts.residual3 + (L-1) * B * T * C; // write to last layer's residual
layernorm_backward(ctx, dresidual, grads.lnfw, grads.lnfb, grads_acts.lnf, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C);
// walk the layers from last to first
for (int l = L-1; l >= 0; l--) {
printf("Backward Pass Layer %d\n", l);
// this layer's input (and its gradient): the encoding for layer 0, else the previous layer's residual3
residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;
dresidual = l == 0 ? grads_acts.encoded : grads_acts.residual3 + (l-1) * B * T * C;
// get the pointers of the weights for this layer
float* l_ln1w = params.ln1w + l * C;
float* l_qkvw = params.qkvw + l * 3*C * C;
float* l_attprojw = params.attprojw + l * C * C;
float* l_ln2w = params.ln2w + l * C;
float* l_fcw = params.fcw + l * 4*C * C;
float* l_fcprojw = params.fcprojw + l * C * 4*C;
// get the pointers of the gradients of the weights for this layer
float* dl_ln1w = grads.ln1w + l * C;
float* dl_ln1b = grads.ln1b + l * C;
float* dl_qkvw = grads.qkvw + l * 3*C * C;
float* dl_qkvb = grads.qkvb + l * 3*C;
float* dl_attprojw = grads.attprojw + l * C * C;
float* dl_attprojb = grads.attprojb + l * C;
float* dl_ln2w = grads.ln2w + l * C;
float* dl_ln2b = grads.ln2b + l * C;
float* dl_fcw = grads.fcw + l * 4*C * C;
float* dl_fcb = grads.fcb + l * 4*C;
float* dl_fcprojw = grads.fcprojw + l * C * 4*C;
float* dl_fcprojb = grads.fcprojb + l * C;
// get the pointers of the activations for this layer
float* l_ln1 = acts.ln1 + l * B * T * C;
float* l_ln1_mean = acts.ln1_mean + l * B * T;
float* l_ln1_rstd = acts.ln1_rstd + l * B * T;
float* l_qkv = acts.qkv + l * B * T * 3*C;
float* l_atty = acts.atty + l * B * T * C;
float* l_att = acts.att + l * B * NH * T * T;
float* l_residual2 = acts.residual2 + l * B * T * C;
float* l_ln2 = acts.ln2 + l * B * T * C;
float* l_ln2_mean = acts.ln2_mean + l * B * T;
float* l_ln2_rstd = acts.ln2_rstd + l * B * T;
float* l_fch = acts.fch + l * B * T * 4*C;
float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C;
// get the pointers of the gradients of the activations for this layer
float* dl_ln1 = grads_acts.ln1 + l * B * T * C;
float* dl_qkv = grads_acts.qkv + l * B * T * 3*C;
float* dl_atty = grads_acts.atty + l * B * T * C;
float* dl_preatt = grads_acts.preatt + l * B * NH * T * T;
float* dl_att = grads_acts.att + l * B * NH * T * T;
float* dl_attproj = grads_acts.attproj + l * B * T * C;
float* dl_residual2 = grads_acts.residual2 + l * B * T * C;
float* dl_ln2 = grads_acts.ln2 + l * B * T * C;
float* dl_fch = grads_acts.fch + l * B * T * 4*C;
float* dl_fch_gelu = grads_acts.fch_gelu + l * B * T * 4*C;
float* dl_fcproj = grads_acts.fcproj + l * B * T * C;
float* dl_residual3 = grads_acts.residual3 + l * B * T * C;
// backprop this layer (mirror image of the op sequence in gpt2_forward)
printf(" [Backward] : Residual2\n");
residual_backward(ctx, dl_residual2, dl_fcproj, dl_residual3, B*T*C);
printf(" [Backward] : FF Down \n");
matmul_backward(ctx, dl_fch_gelu, dl_fcprojw, dl_fcprojb, dl_fcproj, l_fch_gelu, l_fcprojw, B, T, 4*C, C);
printf(" [Backward] : GELU\n");
gelu_backward(ctx, dl_fch, l_fch, dl_fch_gelu, B*T*4*C);
printf(" [Backward] : FF Up\n");
matmul_backward(ctx, dl_ln2, dl_fcw, dl_fcb, dl_fch, l_ln2, l_fcw, B, T, C, 4*C);
printf(" [Backward] : LayerNorm2\n");
layernorm_backward(ctx, dl_residual2, dl_ln2w, dl_ln2b, dl_ln2, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C);
printf(" [Backward] : Residual1\n");
residual_backward(ctx, dresidual, dl_attproj, dl_residual2, B*T*C);
printf(" [Backward] : Attention Projection\n");
matmul_backward(ctx, dl_atty, dl_attprojw, dl_attprojb, dl_attproj, l_atty, l_attprojw, B, T, C, C);
printf(" [Backward] : Attention\n");
attention_backward(ctx, dl_qkv, dl_preatt, dl_att, dl_atty, l_qkv, l_att, B, T, C, NH);
printf(" [Backward] : QKV Projection\n");
matmul_backward(ctx, dl_ln1, dl_qkvw, dl_qkvb, dl_qkv, l_ln1, l_qkvw, B, T, C, 3*C);
printf(" [Backward] : LayerNorm1\n");
layernorm_backward(ctx, dresidual, dl_ln1w, dl_ln1b, dl_ln1, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C);
}
// finally propagate into the token and position embedding tables, using the cached inputs
encoder_backward(ctx, grads.wte, grads.wpe, grads_acts.encoded, model->inputs, B, T, C);
}
// Applies one AdamW step to every parameter, on the CPU copies.
// reference: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
//
//   model         - model whose params_memory is updated from grads_memory
//   learning_rate - step size
//   beta1, beta2  - decay rates of the first / second moment estimates
//   eps           - added to sqrt(v_hat) in the denominator
//   weight_decay  - decoupled weight decay coefficient
//   t             - 1-based step number, used for bias correction
//
// The moment buffers (m_memory, v_memory) are allocated, zero-filled, on the
// first call. Exits with an error message if no gradients exist yet (backward
// was never run) or if the moment buffers cannot be allocated.
void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) {
    // without gradients there is nothing to apply; fail loudly rather than dereference NULL
    if (model->grads_memory == NULL) {
        printf("Error: must run backward before update\n");
        exit(1);
    }

    // lazily allocate the memory for m_memory and v_memory
    if (model->m_memory == NULL) {
        model->m_memory = (float*)calloc(model->num_parameters, sizeof(float));
        model->v_memory = (float*)calloc(model->num_parameters, sizeof(float));
        // calloc(0, ...) may legitimately return NULL, so only treat NULL as failure for a non-empty model
        if (model->num_parameters > 0 && (model->m_memory == NULL || model->v_memory == NULL)) {
            printf("Error: failed to allocate AdamW moment buffers\n");
            exit(EXIT_FAILURE);
        }
    }

    // the bias-correction denominators depend only on t, so compute them once
    const float bias_correction1 = 1.0f - powf(beta1, t);
    const float bias_correction2 = 1.0f - powf(beta2, t);

    for (size_t i = 0; i < model->num_parameters; i++) {
        float param = model->params_memory[i];
        float grad = model->grads_memory[i];

        // update the first moment (momentum)
        float m = beta1 * model->m_memory[i] + (1.0f - beta1) * grad;
        // update the second moment (RMSprop)
        float v = beta2 * model->v_memory[i] + (1.0f - beta2) * grad * grad;
        // bias-correct both moments
        float m_hat = m / bias_correction1;
        float v_hat = v / bias_correction2;

        // update
        model->m_memory[i] = m;
        model->v_memory[i] = v;
        model->params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param);
    }
}
// Releases every heap buffer owned by the model and resets the pointers to
// NULL, so that a repeated call is harmless and the NULL checks used for lazy
// allocation elsewhere never see a dangling pointer.
// Buffers that were never allocated are NULL already; free(NULL) is a no-op.
// GPU tensors held in params_ / acts_ are not touched here.
void gpt2_free(GPT2 *model) {
    free(model->params_memory);
    model->params_memory = NULL;
    free(model->grads_memory);
    model->grads_memory = NULL;
    free(model->m_memory);
    model->m_memory = NULL;
    free(model->v_memory);
    model->v_memory = NULL;
    free(model->acts_memory);
    model->acts_memory = NULL;
    free(model->grads_acts_memory);
    model->grads_acts_memory = NULL;
    free(model->inputs);
    model->inputs = NULL;
    free(model->targets);
    model->targets = NULL;
}
#ifndef TESTING
// if we are TESTING (see test_gpt2.c), we'll skip the int main below
// ----------------------------------------------------------------------------
// sampler
// Advances the 64-bit generator state in place and returns 32 pseudo-random
// bits (the upper half of the scrambled state).
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
// A state of 0 is a fixed point: it stays 0 and always yields 0.
unsigned int random_u32(uint64_t *state) {
    uint64_t x = *state;
    x ^= x >> 12;
    x ^= x << 25;
    x ^= x >> 27;
    *state = x;
    return (unsigned int)((x * 0x2545F4914F6CDD1Dull) >> 32);
}
// Returns a pseudo-random float32 in [0,1): the top 24 bits of a fresh
// random_u32 draw divided by 2^24. Advances `state`.
float random_f32(uint64_t *state) {
    const unsigned int top24 = random_u32(state) >> 8;
    return top24 / 16777216.0f;
}
// Draws an index from a discrete distribution by walking its running sum.
//   probabilities - n values that must sum to 1
//   coin          - random number in [0, 1), usually from random_f32()
// Returns the first index whose cumulative probability exceeds coin, or
// n - 1 if rounding errors keep the total from ever exceeding it.
int sample_mult(float* probabilities, int n, float coin) {
    float running_total = 0.0f;
    int idx = 0;
    while (idx < n) {
        running_total += probabilities[idx];
        if (running_total > coin) {
            return idx;
        }
        ++idx;
    }
    return n - 1; // in case of rounding errors
}
// ----------------------------------------------------------------------------
// main training loop
// Entry point: loads the GPT-2 checkpoint "gpt2_124M.bin", then runs 41
// training steps (step 0..40) with B=4, T=64. Every 10 steps the validation
// loss is averaged over val_num_batches batches; at steps 20 and 40 a short
// text sample is generated. All file paths are relative to the working directory.
int main() {
setLogLevel(kWarn);
printf("Creating GPU context\n");
// request the buffer-size limits named LIMITS_BUFFER_SIZE_1GB (defined outside this file)
WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB;
gpu::Context ctx = gpu::createContext({}, {}, {
.requiredLimits = &requiredLimits
});
// gpu::Context ctx = gpu::createContext();
// build the GPT-2 model from a checkpoint
GPT2 model;
gpt2_build_from_checkpoint(ctx, &model, "gpt2_124M.bin");
// build the DataLoaders from tokens files. for now use tiny_shakespeare if available, else tiny_stories
const char* tiny_stories_train = "dev/data/tinystories/TinyStories_train.bin";
const char* tiny_stories_val = "dev/data/tinystories/TinyStories_val.bin";
const char* tiny_shakespeare_train = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin";
const char* tiny_shakespeare_val = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin";
// access(..., F_OK) only tests that the file exists
const char* train_tokens = access(tiny_shakespeare_train, F_OK) != -1 ? tiny_shakespeare_train : tiny_stories_train;
const char* val_tokens = access(tiny_shakespeare_val, F_OK) != -1 ? tiny_shakespeare_val : tiny_stories_val;
constexpr int B = 4; // batch size 4 (i.e. 4 independent token sequences will be trained on)
constexpr int T = 64; // sequence length 64 (i.e. each sequence is 64 tokens long). must be <= maxT, which is 1024 for GPT-2
DataLoader train_loader, val_loader;
// NOTE(review): the trailing arguments (0, 1, 1) / (0, 1, 0) are presumably
// rank, process count and a shuffle flag - confirm against llmc/dataloader.h
dataloader_init(&train_loader, train_tokens, B, T, 0, 1, 1);
dataloader_init(&val_loader, val_tokens, B, T, 0, 1, 0);
printf("train dataset num_batches: %zu\n", train_loader.num_tokens / (B*T));
printf("val dataset num_batches: %zu\n", val_loader.num_tokens / (B*T));
int val_num_batches = 5;
// build the Tokenizer
Tokenizer tokenizer;
tokenizer_init(&tokenizer, "gpt2_tokenizer.bin");
// some memory for generating samples from the model
uint64_t rng_state = 1337;
int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int));
const int genT = 64; // number of steps of inference we will do
// train
struct timespec start, end;
printf("Starting training\n");
for (int step = 0; step <= 40; step++) {
printf("Step %d\n", step);
// once in a while estimate the validation loss
if (step % 10 == 0) {
float val_loss = 0.0f;
dataloader_reset(&val_loader);
for (int i = 0; i < val_num_batches; i++) {
dataloader_next_batch(&val_loader);
gpt2_forward(ctx, &model, val_loader.inputs, val_loader.targets, B, T);
val_loss += model.mean_loss;
}
val_loss /= val_num_batches;
printf("val loss %f\n", val_loss);
}
// once in a while do model inference to print generated text
if (step > 0 && step % 20 == 0) {
// fill up gen_tokens with the GPT2_EOT, which kicks off the generation
for(int i = 0; i < B * T; ++i) {
gen_tokens[i] = tokenizer.eot_token;
}
// now sample from the model autoregressively
printf("generating:\n---\n");
for (int t = 1; t < genT; t++) {
// note that inference is very wasteful here because for each token
// we re-calculate the forward pass for all of (B,T) positions from scratch
// but the inference here is just for sanity checking anyway
// and we can maybe optimize a bit more later, with careful tests
gpt2_forward(ctx, &model, gen_tokens, NULL, B, T);
// furthermore, below we're only using b=0 (i.e. the first row) of all B rows
// we're in principle running B "inference streams" in parallel here
// but only using position 0
// get the Vp-dimensional vector probs[0, t-1, :]
float* probs = model.acts.probs + (t-1) * model.config.padded_vocab_size;
float coin = random_f32(&rng_state);
// note we're only sampling from the first V elements, ignoring padding
// (the probabilities in the padded region should be zero anyway)
int next_token = sample_mult(probs, model.config.vocab_size, coin);
gen_tokens[t] = next_token;
// print the generated token, either using the Tokenizer or a fallback
if (tokenizer.init_ok) {
const char* token_str = tokenizer_decode(&tokenizer, next_token);
safe_printf(token_str);
} else {
// fall back to printing the token id
printf("%d ", next_token);
}
fflush(stdout);
}
printf("\n---\n");
}
// do a training step; the timed region covers batch loading, forward, backward and the optimizer update
clock_gettime(CLOCK_MONOTONIC, &start);
dataloader_next_batch(&train_loader);
gpt2_forward(ctx, &model, train_loader.inputs, train_loader.targets, B, T);
gpt2_zero_grad(&model);
gpt2_backward(ctx, &model);
// AdamW with lr=1e-4, beta1=0.9, beta2=0.999, eps=1e-8, no weight decay; step+1 because t is 1-based
gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);
clock_gettime(CLOCK_MONOTONIC, &end);
double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
printf("step %d: train loss %f (took %f ms)\n", step, model.mean_loss, time_elapsed_s * 1000);
}
// free
dataloader_free(&train_loader);
dataloader_free(&val_loader);
tokenizer_free(&tokenizer);
gpt2_free(&model);
free(gen_tokens);
return 0;
}
#endif