From 1496a642894f97d5e932a8647da0f7021ef0cc8f Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 22 Mar 2023 16:12:27 +0000 Subject: [PATCH] Fix on logging embedding norms data. Save steps now matches logging steps. --- src/hf_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/hf_trainer.py b/src/hf_trainer.py index 5601c42..7cb31cd 100644 --- a/src/hf_trainer.py +++ b/src/hf_trainer.py @@ -44,7 +44,7 @@ def compute_metrics(eval_pred: PredictionOutput) -> Dict[str, float]: "R@1": recall_at_1.item(), "R@5": recall_at_5.item(), "MRR": mean_reciprocal_rank.item(), - "embedding_norms": embedding_norms, + "embedding_norms": [norm for norm in embedding_norms], "min_embedding_norm": embedding_norms.min().item(), } @@ -263,6 +263,7 @@ def get_trainer( logging_strategy="steps", logging_steps=log_every, save_strategy="steps", + save_steps=log_every, evaluation_strategy="steps", report_to="wandb", )