Skip to content

Commit 1496a64

Browse files
committed
Fix on logging embedding norms data. Save steps now matches logging steps.
1 parent 102f0e6 commit 1496a64

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/hf_trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def compute_metrics(eval_pred: PredictionOutput) -> Dict[str, float]:
4444
"R@1": recall_at_1.item(),
4545
"R@5": recall_at_5.item(),
4646
"MRR": mean_reciprocal_rank.item(),
47-
"embedding_norms": embedding_norms,
47+
"embedding_norms": [norm for norm in embedding_norms],
4848
"min_embedding_norm": embedding_norms.min().item(),
4949
}
5050

@@ -263,6 +263,7 @@ def get_trainer(
263263
logging_strategy="steps",
264264
logging_steps=log_every,
265265
save_strategy="steps",
266+
save_steps=log_every,
266267
evaluation_strategy="steps",
267268
report_to="wandb",
268269
)

0 commit comments

Comments
 (0)