Skip to content

Commit ef98b2c

Browse files
authored
[TXL/PyT] update: (NVIDIA#989)
* changed API calls to torch.einsum
* added export OMP_NUM_THREADS=1 to all launcher scripts
* additional runtime checks to ensure that launch configuration is valid
1 parent 706ef49 commit ef98b2c

16 files changed

Lines changed: 61 additions & 27 deletions

PyTorch/LanguageModeling/Transformer-XL/README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,11 @@ perplexity on the test dataset.
11131113

11141114
## Performance
11151115

1116-
The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
1116+
The performance measurements in this document were conducted at the time of
1117+
publication and may not reflect the performance achieved from NVIDIA’s latest
1118+
software release. For the most up-to-date performance measurements, go to
1119+
[NVIDIA Data Center Deep Learning Product
1120+
Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
11171121

11181122
### Benchmarking
11191123

PyTorch/LanguageModeling/Transformer-XL/pytorch/inference/mem_transformer_jit.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def forward(self, h, attn_mask=None, mems=None):
122122
head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
123123

124124
# [bsz x n_head x qlen x klen]
125-
attn_score = torch.einsum('ibnd,jbnd->bnij', (head_q, head_k))
125+
attn_score = torch.einsum('ibnd,jbnd->bnij', head_q, head_k)
126126
attn_score.mul_(self.scale)
127127
if attn_mask is not None:
128128
if attn_mask.dim() == 2:
@@ -135,7 +135,7 @@ def forward(self, h, attn_mask=None, mems=None):
135135
attn_prob = self.dropatt(attn_prob)
136136

137137
# [bsz x n_head x qlen x klen] * [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
138-
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, head_v))
138+
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, head_v)
139139
attn_vec = attn_vec.contiguous().view(
140140
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
141141

@@ -262,13 +262,13 @@ def forward(self, w, r, r_w_bias, r_r_bias, attn_mask,
262262

263263
# compute attention score
264264
rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head
265-
# AC = torch.einsum('ibnd,jbnd->bnij', (rw_head_q, w_head_k)) # bsz x n_head x qlen x klen
265+
# AC = torch.einsum('ibnd,jbnd->bnij', rw_head_q, w_head_k) # bsz x n_head x qlen x klen
266266
rw_head_q = rw_head_q.view(qlen, bsz * self.n_head, self.d_head).permute(1, 0, 2)
267267
w_head_k = w_head_k.reshape(klen, bsz * self.n_head, self.d_head).permute(1, 2, 0)
268268
AC = torch.bmm(rw_head_q, w_head_k).view(bsz, self.n_head, qlen, klen)
269269

270270
rr_head_q = w_head_q + r_r_bias
271-
# BD = torch.einsum('ibnd,jnd->bnij', (rr_head_q, r_head_k)) # bsz x n_head x qlen x klen
271+
# BD = torch.einsum('ibnd,jnd->bnij', rr_head_q, r_head_k) # bsz x n_head x qlen x klen
272272
rr_head_q = rr_head_q.permute(2, 1, 0, 3).reshape(self.n_head, bsz * qlen, self.d_head)
273273
r_head_k = r_head_k.permute(1, 2, 0).view(self.n_head, self.d_head, klen)
274274
BD = torch.bmm(rr_head_q, r_head_k).view(self.n_head, bsz, qlen, klen).permute(1, 0, 2, 3)
@@ -290,7 +290,7 @@ def forward(self, w, r, r_w_bias, r_r_bias, attn_mask,
290290
attn_prob = self.dropatt(attn_prob)
291291

292292
# compute attention vector
293-
# attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, w_head_v))
293+
# attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, w_head_v)
294294
attn_prob = attn_prob.view(bsz * self.n_head, qlen, klen)
295295
w_head_v = w_head_v.permute(1, 2, 0, 3).reshape(bsz * self.n_head, klen, self.d_head)
296296
attn_vec = torch.bmm(attn_prob, w_head_v).permute(1, 0, 2).view(qlen, bsz, self.n_head, self.d_head)
@@ -358,11 +358,11 @@ def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
358358
r_bias = r_bias.t()
359359

360360
# compute attention score
361-
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
361+
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
362362

363-
AC = torch.einsum('ibnd,jbnd->bnij', (rw_head_q, w_head_k)) # bsz x n_head x qlen x klen
364-
B_ = torch.einsum('ibnd,jnd->bnij', (w_head_q, r_emb)) # bsz x n_head x qlen x klen
365-
D_ = r_bias[None, :, None, :] # 1 x n_head x 1 x klen
363+
AC = torch.einsum('ibnd,jbnd->bnij', rw_head_q, w_head_k) # bsz x n_head x qlen x klen
364+
B_ = torch.einsum('ibnd,jnd->bnij', w_head_q, r_emb) # bsz x n_head x qlen x klen
365+
D_ = r_bias[None, :, None, :] # 1 x n_head x 1 x klen
366366
BD = self._rel_shift(B_ + D_)
367367

368368
# [bsz x qlen x klen x n_head]
@@ -381,7 +381,7 @@ def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
381381
attn_prob = self.dropatt(attn_prob)
382382

383383
# compute attention vector
384-
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, w_head_v))
384+
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, w_head_v)
385385

386386
# [qlen x bsz x n_head x d_head]
387387
attn_vec = attn_vec.contiguous().view(

PyTorch/LanguageModeling/Transformer-XL/pytorch/inference/proj_adaptive_softmax_jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def _compute_logit(self, hidden, weight, bias, proj: Optional[torch.Tensor]):
146146
if proj is None:
147147
logit = F.linear(hidden, weight, bias=bias)
148148
else:
149-
logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
149+
logit = torch.einsum('bd,de,ev->bv', hidden, proj, weight.t())
150150
if bias is not None:
151151
logit = logit + bias
152152
return logit

PyTorch/LanguageModeling/Transformer-XL/pytorch/mem_transformer.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def forward(self, h, attn_mask=None, mems=None):
125125
head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
126126

127127
# [bsz x n_head x qlen x klen]
128-
attn_score = torch.einsum('ibnd,jbnd->bnij', (head_q, head_k))
128+
attn_score = torch.einsum('ibnd,jbnd->bnij', head_q, head_k)
129129
attn_score.mul_(self.scale)
130130
if attn_mask is not None:
131131
if attn_mask.dim() == 2:
@@ -138,7 +138,7 @@ def forward(self, h, attn_mask=None, mems=None):
138138
attn_prob = self.dropatt(attn_prob)
139139

140140
# [bsz x n_head x qlen x klen] * [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
141-
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, head_v))
141+
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, head_v)
142142
attn_vec = attn_vec.contiguous().view(
143143
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
144144

@@ -264,10 +264,10 @@ def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
264264

265265
# compute attention score
266266
rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head
267-
AC = torch.einsum('ibnd,jbnd->bnij', (rw_head_q, w_head_k)) # bsz x n_head x qlen x klen
267+
AC = torch.einsum('ibnd,jbnd->bnij', rw_head_q, w_head_k) # bsz x n_head x qlen x klen
268268

269269
rr_head_q = w_head_q + r_r_bias
270-
BD = torch.einsum('ibnd,jnd->bnij', (rr_head_q, r_head_k)) # bsz x n_head x qlen x klen
270+
BD = torch.einsum('ibnd,jnd->bnij', rr_head_q, r_head_k) # bsz x n_head x qlen x klen
271271
BD = self._rel_shift(BD)
272272

273273
# [bsz x n_head x qlen x klen]
@@ -285,7 +285,7 @@ def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
285285
attn_prob = self.dropatt(attn_prob)
286286

287287
# compute attention vector
288-
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, w_head_v))
288+
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, w_head_v)
289289

290290
# [qlen x bsz x n_head x d_head]
291291
attn_vec = attn_vec.contiguous().view(
@@ -350,11 +350,11 @@ def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
350350
r_bias = r_bias.t()
351351

352352
# compute attention score
353-
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
353+
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
354354

355-
AC = torch.einsum('ibnd,jbnd->bnij', (rw_head_q, w_head_k)) # bsz x n_head x qlen x klen
356-
B_ = torch.einsum('ibnd,jnd->bnij', (w_head_q, r_emb)) # bsz x n_head x qlen x klen
357-
D_ = r_bias[None, :, None, :] # 1 x n_head x 1 x klen
355+
AC = torch.einsum('ibnd,jbnd->bnij', rw_head_q, w_head_k) # bsz x n_head x qlen x klen
356+
B_ = torch.einsum('ibnd,jnd->bnij', w_head_q, r_emb) # bsz x n_head x qlen x klen
357+
D_ = r_bias[None, :, None, :] # 1 x n_head x 1 x klen
358358
BD = self._rel_shift(B_ + D_)
359359

360360
# [bsz x qlen x klen x n_head]
@@ -372,7 +372,7 @@ def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
372372
attn_prob = self.dropatt(attn_prob)
373373

374374
# compute attention vector
375-
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, w_head_v))
375+
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, w_head_v)
376376

377377
# [qlen x bsz x n_head x d_head]
378378
attn_vec = attn_vec.contiguous().view(

PyTorch/LanguageModeling/Transformer-XL/pytorch/run_enwik8_base.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/bin/bash
22

3+
export OMP_NUM_THREADS=1
4+
35
if [[ $1 == 'train' ]]; then
46
echo 'Run training...'
57
python train.py \

PyTorch/LanguageModeling/Transformer-XL/pytorch/run_enwik8_large.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/bin/bash
22

3+
export OMP_NUM_THREADS=1
4+
35
if [[ $1 == 'train' ]]; then
46
echo 'Run training...'
57
python train.py \

PyTorch/LanguageModeling/Transformer-XL/pytorch/run_lm1b_base.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/bin/bash
22

3+
export OMP_NUM_THREADS=1
4+
35
if [[ $1 == 'train' ]]; then
46
echo 'Run training...'
57
python train.py \

PyTorch/LanguageModeling/Transformer-XL/pytorch/run_lm1b_large.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/bin/bash
22

3+
export OMP_NUM_THREADS=1
4+
35
if [[ $1 == 'train' ]]; then
46
echo 'Run training...'
57
python train.py \

PyTorch/LanguageModeling/Transformer-XL/pytorch/run_multinode_wt103_large.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
export OMP_NUM_THREADS=1
18+
1719
if [[ $1 == 'train' ]] || [[ $1 == 'all' ]]; then
1820
echo 'Run training...'
1921
python train.py \

PyTorch/LanguageModeling/Transformer-XL/pytorch/run_text8_base.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/bin/bash
22

3+
export OMP_NUM_THREADS=1
4+
35
if [[ $1 == 'train' ]]; then
46
echo 'Run training...'
57
python train.py \

0 commit comments

Comments
 (0)