Skip to content

Commit 5a58cd9

Browse files
authored
Use reshape instead of view: some people get errors with view
1 parent e7b6871 commit 5a58cd9

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

models/layers.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
131131
query, key, value = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (query, key, value)) # needed for scaled_dot_product_attention but not flash_attn_func
132132
attn_output = scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
133133
attn_output = einops.rearrange(attn_output, 'B H S D -> B S H D')
134-
attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore
134+
attn_output = attn_output.reshape(batch_size, seq_len, self.output_size) # type: ignore
135135
return self.o_proj(attn_output)
136136

137137
class LinearSwish(nn.Module):

0 commit comments

Comments
 (0)