Skip to content

Commit 4985930

Browse files
Update
1 parent 1d8e435 commit 4985930

1 file changed

Lines changed: 107 additions & 21 deletions

File tree

experimental/kernels/gpt2_webgpu_aot.cpp

Lines changed: 107 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ typedef struct {
272272
Tensor targets; // the target tokens for the current forward pass
273273
float mean_loss; // after a forward pass with targets, will be populated with the mean loss
274274
float* mean_loss_buffer;
275+
float* probs_buffer;
275276

276277
Tensor nullTensor;
277278

@@ -377,6 +378,7 @@ void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoin
377378
model->mean_loss = -1.0f; // -1.0f will designate no loss
378379
// Allocate a B * T (batch_size * seq_len) buffer for the per-position losses used to compute the mean loss
379380
model->mean_loss_buffer = (float*)mallocCheck(sizeof(float) * model->batch_size * model->seq_len);
381+
model->probs_buffer = (float*)mallocCheck(sizeof(float) * model->batch_size * model->seq_len * Vp);
380382

381383
printf("Model build complete\n");
382384

@@ -616,7 +618,8 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si
616618

617619
printf("Crossentropy\n");
618620
// also forward the cross-entropy loss function if we have the targets
619-
// if (targets != NULL) {
621+
// A targets tensor whose first dimension is 1 means we don't have targets
622+
if (targets.shape[0] != 1) {
620623
// crossentropy_forward(ctx, model->acts.losses, model->acts.probs, targets, B, T, Vp);
621624
{
622625
std::promise<void> promise;
@@ -627,13 +630,14 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si
627630
// for convenience also evaluate the mean loss
628631
float mean_loss = 0.0f;
629632
//toCPU(ctx, model->acts_.data[22], model->acts.losses.data, model->act_sizes[22] * sizeof(float));
630-
for (int i=0; i<B*T; i++) { mean_loss += model->acts.losses.data[i]; }
633+
toCPU(ctx, model->acts.losses, model->mean_loss_buffer, B*T * sizeof(float));
634+
for (int i=0; i<B*T; i++) { mean_loss += model->mean_loss_buffer[i]; }
631635
mean_loss /= B*T;
632636
model->mean_loss = mean_loss;
633-
// } else {
634-
// // if we don't have targets, we don't have a loss
635-
// model->mean_loss = -1.0f;
636-
// }
637+
} else {
638+
// if we don't have targets, we don't have a loss
639+
model->mean_loss = -1.0f;
640+
}
637641
printf("Forward pass done\n");
638642
}
639643

@@ -654,8 +658,8 @@ void gpt2_backward(Context& ctx, GPT2 *model) {
654658
// lazily allocate the memory for gradients of the weights and activations, if needed
655659
if (model->grads_memory == NULL) {
656660
printf("Allocating %.2f MB for gradients\n", model->num_parameters * sizeof(float) / (1024.0f * 1024.0f));
657-
malloc_and_point_parameters(&model->grads, model->param_sizes);
658-
malloc_and_point_activations(&model->grads_acts, model->act_sizes);
661+
malloc_and_point_parameters(ctx, &model->grads, model->param_sizes);
662+
malloc_and_point_activations(ctx, &model->grads_acts, model->act_sizes);
659663
gpt2_zero_grad(model);
660664
}
661665

@@ -678,8 +682,9 @@ void gpt2_backward(Context& ctx, GPT2 *model) {
678682
// technically this is a small, inline backward() pass of calculating
679683
// total, final loss as the mean over all losses over all (B,T) positions in the batch
680684
float dloss_mean = 1.0f / (B*T);
681-
for (int i = 0; i < B*T; i++) { grads_acts.losses.data[i] = dloss_mean; }
682-
toGPU(ctx, grads_acts.losses.data, model->acts_.data[22]);
685+
for (int i = 0; i < B*T; i++) { model->mean_loss_buffer[i] = dloss_mean; }
686+
toGPU(ctx, model->mean_loss_buffer, model->acts.losses);
687+
//toGPU(ctx, grads_acts.losses.data, model->acts_.data[22]);
683688

684689
// crossentropy_softmax_backward(ctx, grads_acts.logits, grads_acts.losses, acts.probs, model->targets, B, T, V, Vp);
685690
{
@@ -794,11 +799,11 @@ void gpt2_backward(Context& ctx, GPT2 *model) {
794799
dispatchKernel(ctx, model->kernels.encoder_backward, promise);
795800
wait(ctx, future);
796801
}
797-
toCPU(ctx, model->params_.data[0], model->grads.wte.data, model->param_sizes[0] * sizeof(float));
798-
toCPU(ctx, model->params_.data[1], model->grads.wpe.data, model->param_sizes[1] * sizeof(float));
802+
// toCPU(ctx, model->params_.data[0], model->grads.wte.data, model->param_sizes[0] * sizeof(float));
803+
// toCPU(ctx, model->params_.data[1], model->grads.wpe.data, model->param_sizes[1] * sizeof(float));
799804
}
800805

801-
void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) {
806+
void gpt2_update(Context& ctx, GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) {
802807
// reference: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
803808

804809
// lazily allocate the memory for m_memory and v_memory
@@ -807,6 +812,45 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
807812
model->v_memory = (float*)calloc(model->num_parameters, sizeof(float));
808813
}
809814

815+
// Copy the parameters to the CPU
816+
float* iter = model->params_memory;
817+
toCPU(ctx, model->params.wte, iter, model->param_sizes[0] * sizeof(float));
818+
iter += model->param_sizes[0];
819+
toCPU(ctx, model->params.wpe, iter, model->param_sizes[1] * sizeof(float));
820+
iter += model->param_sizes[1];
821+
size_t L = model->config.num_layers;
822+
for (int l = 0; l < L; l++) {
823+
toCPU(ctx, model->params.ln1w[l], iter, model->param_sizes[2]/L * sizeof(float));
824+
iter += model->param_sizes[2]/L;
825+
toCPU(ctx, model->params.ln1b[l], iter, model->param_sizes[3]/L * sizeof(float));
826+
iter += model->param_sizes[3]/L;
827+
toCPU(ctx, model->params.qkvw[l], iter, model->param_sizes[4]/L * sizeof(float));
828+
iter += model->param_sizes[4]/L;
829+
toCPU(ctx, model->params.qkvb[l], iter, model->param_sizes[5]/L * sizeof(float));
830+
iter += model->param_sizes[5]/L;
831+
toCPU(ctx, model->params.attprojw[l], iter, model->param_sizes[6]/L * sizeof(float));
832+
iter += model->param_sizes[6]/L;
833+
toCPU(ctx, model->params.attprojb[l], iter, model->param_sizes[7]/L * sizeof(float));
834+
iter += model->param_sizes[7]/L;
835+
toCPU(ctx, model->params.ln2w[l], iter, model->param_sizes[8]/L * sizeof(float));
836+
iter += model->param_sizes[8]/L;
837+
toCPU(ctx, model->params.ln2b[l], iter, model->param_sizes[9]/L * sizeof(float));
838+
iter += model->param_sizes[9]/L;
839+
toCPU(ctx, model->params.fcw[l], iter, model->param_sizes[10]/L * sizeof(float));
840+
iter += model->param_sizes[10]/L;
841+
toCPU(ctx, model->params.fcb[l], iter, model->param_sizes[11]/L * sizeof(float));
842+
iter += model->param_sizes[11]/L;
843+
toCPU(ctx, model->params.fcprojw[l], iter, model->param_sizes[12]/L * sizeof(float));
844+
iter += model->param_sizes[12]/L;
845+
toCPU(ctx, model->params.fcprojb[l], iter, model->param_sizes[13]/L * sizeof(float));
846+
iter += model->param_sizes[13]/L;
847+
}
848+
toCPU(ctx, model->params.lnfw, iter, model->param_sizes[14] * sizeof(float));
849+
iter += model->param_sizes[14];
850+
toCPU(ctx, model->params.lnfb, iter, model->param_sizes[15] * sizeof(float));
851+
iter += model->param_sizes[15];
852+
853+
810854
for (size_t i = 0; i < model->num_parameters; i++) {
811855
float param = model->params_memory[i];
812856
float grad = model->grads_memory[i];
@@ -824,8 +868,43 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
824868
model->v_memory[i] = v;
825869
model->params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param);
826870
}
827-
toGPU(ctx, model->params_memory, model->params_.data[0]);
828-
toGPU(ctx, model->params_memory + model->param_sizes[0], model->params_.data[1]);
871+
// toGPU(ctx, model->params_memory, model->params_.data[0]);
872+
// toGPU(ctx, model->params_memory + model->param_sizes[0], model->params_.data[1]);
873+
iter = model->params_memory;
874+
toGPU(ctx, iter, model->params.wte);
875+
iter += model->param_sizes[0];
876+
toGPU(ctx, iter, model->params.wpe);
877+
iter += model->param_sizes[1];
878+
for (int l = 0; l < L; l++) {
879+
toGPU(ctx, iter, model->params.ln1w[l]);
880+
iter += model->param_sizes[2]/L;
881+
toGPU(ctx, iter, model->params.ln1b[l]);
882+
iter += model->param_sizes[3]/L;
883+
toGPU(ctx, iter, model->params.qkvw[l]);
884+
iter += model->param_sizes[4]/L;
885+
toGPU(ctx, iter, model->params.qkvb[l]);
886+
iter += model->param_sizes[5]/L;
887+
toGPU(ctx, iter, model->params.attprojw[l]);
888+
iter += model->param_sizes[6]/L;
889+
toGPU(ctx, iter, model->params.attprojb[l]);
890+
iter += model->param_sizes[7]/L;
891+
toGPU(ctx, iter, model->params.ln2w[l]);
892+
iter += model->param_sizes[8]/L;
893+
toGPU(ctx, iter, model->params.ln2b[l]);
894+
iter += model->param_sizes[9]/L;
895+
toGPU(ctx, iter, model->params.fcw[l]);
896+
iter += model->param_sizes[10]/L;
897+
toGPU(ctx, iter, model->params.fcb[l]);
898+
iter += model->param_sizes[11]/L;
899+
toGPU(ctx, iter, model->params.fcprojw[l]);
900+
iter += model->param_sizes[12]/L;
901+
toGPU(ctx, iter, model->params.fcprojb[l]);
902+
iter += model->param_sizes[13]/L;
903+
}
904+
toGPU(ctx, iter, model->params.lnfw);
905+
iter += model->param_sizes[14];
906+
toGPU(ctx, iter, model->params.lnfb);
907+
iter += model->param_sizes[15];
829908
}
830909

831910
void gpt2_free(GPT2 *model) {
@@ -915,6 +994,7 @@ int main() {
915994
Tensor inputs = createTensor(ctx, Shape{B, T}, ki32);
916995
Tensor targets = createTensor(ctx, Shape{B, T}, ki32);
917996
Tensor gen_tokens = createTensor(ctx, Shape{B, T}, ki32);
997+
int* gen_tokens_cpu = (int*)mallocCheck(B * T * sizeof(int));
918998
printf("Starting training\n");
919999
for (int step = 0; step <= 40; step++) {
9201000
printf("Step %d\n", step);
@@ -937,7 +1017,10 @@ int main() {
9371017
// once in a while do model inference to print generated text
9381018
if (step > 0 && step % 20 == 0) {
9391019
// fill up gen_tokens with the GPT2_EOT, which kicks off the generation
940-
toGPU(ctx, tokenizer.eot_token, gen_tokens);
1020+
for(int i = 0; i < B * T; ++i) {
1021+
gen_tokens_cpu[i] = tokenizer.eot_token;
1022+
}
1023+
toGPU(ctx, gen_tokens_cpu, gen_tokens);
9411024
// now sample from the model autoregressively
9421025
printf("generating:\n---\n");
9431026
for (int t = 1; t < genT; t++) {
@@ -950,14 +1033,15 @@ int main() {
9501033
// we're in principle running B "inference streams" in parallel here
9511034
// but only using position 0
9521035
// get the Vp-dimensional vector probs[0, t-1, :]
953-
float* probs = model.acts.probs.data + (t-1) * model.config.padded_vocab_size;
954-
toCPU(ctx, model.acts_.data[21], probs, (t-1) * model.config.padded_vocab_size * sizeof(float));
1036+
toCPU(ctx, model.acts.probs, model.probs_buffer, B * T * model.config.padded_vocab_size * sizeof(float));
1037+
float* probs = model.probs_buffer + (t-1) * model.config.padded_vocab_size;
9551038

9561039
float coin = random_f32(&rng_state);
9571040
// note we're only sampling from the first V elements, ignoring padding
9581041
// (the probabilities in the padded region should be zero anyway)
9591042
int next_token = sample_mult(probs, model.config.vocab_size, coin);
960-
gen_tokens[t] = next_token;
1043+
gen_tokens_cpu[t] = next_token;
1044+
toGPU(ctx, gen_tokens_cpu, gen_tokens);
9611045
// print the generated token, either using the Tokenizer or a fallback
9621046
if (tokenizer.init_ok) {
9631047
const char* token_str = tokenizer_decode(&tokenizer, next_token);
@@ -974,10 +1058,12 @@ int main() {
9741058
// do a training step
9751059
clock_gettime(CLOCK_MONOTONIC, &start);
9761060
dataloader_next_batch(&train_loader);
977-
gpt2_forward(ctx, &model, train_loader.inputs, train_loader.targets, B, T);
1061+
toGPU(ctx, train_loader.inputs, inputs);
1062+
toGPU(ctx, train_loader.targets, targets);
1063+
gpt2_forward(ctx, &model, inputs, targets, B, T);
9781064
gpt2_zero_grad(&model);
9791065
gpt2_backward(ctx, &model);
980-
gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);
1066+
gpt2_update(ctx, &model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);
9811067
clock_gettime(CLOCK_MONOTONIC, &end);
9821068
double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
9831069
printf("step %d: train loss %f (took %f ms)\n", step, model.mean_loss, time_elapsed_s * 1000);

0 commit comments

Comments
 (0)