Description
Hi Megatron-LM team!
While going through the code in `megatron/core/pipeline_parallel/schedules.py`, I noticed that between each forward and backward pass, the line `total_num_tokens += num_tokens.item()` calls the `item()` method.
Megatron-LM/megatron/core/pipeline_parallel/schedules.py
Lines 451 to 467 in 8ca9e57

```python
with no_sync_func():
    for i in range(num_microbatches - 1):
        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
            current_microbatch=i,
        )
        total_num_tokens += num_tokens.item()
        if not forward_only:
            backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
```
From my understanding, `item()` transfers data from the GPU to the host, which forces the CPU to block and wait for the GPU to finish its pending computation. This synchronization can hurt performance by preventing the CPU from enqueuing subsequent kernels while the GPU is busy.
To validate this, I removed the `item()` call and observed that the time cost associated with this operation was completely eliminated.
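For illustration, here is a minimal sketch (not the Megatron-LM code; `run_microbatches` and `fake_forward_step` are hypothetical names) of one way to avoid the per-microbatch sync: keep the accumulator as a tensor inside the loop and call `item()` once after the loop, so at most a single device-to-host transfer occurs.

```python
import torch

def run_microbatches(num_microbatches, forward_step):
    # Accumulate on-device: tensor += tensor does not force a CPU/GPU sync.
    total_num_tokens = torch.zeros(1, dtype=torch.long)
    for i in range(num_microbatches):
        output_tensor, num_tokens = forward_step(i)
        total_num_tokens += num_tokens  # no .item() inside the hot loop
    # Single synchronization point, after all work has been enqueued.
    return total_num_tokens.item()

# Hypothetical stub standing in for forward_step, for illustration only:
# microbatch i "processes" i + 1 tokens.
def fake_forward_step(i):
    return None, torch.tensor([i + 1], dtype=torch.long)

print(run_microbatches(4, fake_forward_step))  # 1 + 2 + 3 + 4 = 10
```

On CPU tensors (as in this self-contained sketch) the difference is invisible, but with CUDA tensors the in-loop version with `item()` would serialize the host against the device on every microbatch, while this version defers the sync to a single point.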
Could you clarify why `item()` is used here?
Thanks for your time and insights!