
[QUESTION] Performance Impact of Using item() in total_num_tokens += num_tokens.item() in megatron/core/pipeline_parallel/schedules.py #1403

@wan-nan

Description

Hi Megatron-LM team!

While going through the code in megatron/core/pipeline_parallel/schedules.py, I noticed that between each microbatch's forward and backward pass, the line `total_num_tokens += num_tokens.item()` calls the `item()` method:

```python
with no_sync_func():
    for i in range(num_microbatches - 1):
        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
            current_microbatch=i,
        )
        total_num_tokens += num_tokens.item()
        if not forward_only:
            backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
```

From my understanding, `item()` copies the scalar from the GPU to the host, which blocks the CPU until the GPU has finished all queued work needed to produce that value. This could have a negative impact on performance, as illustrated below.

(profiler trace screenshot)

To validate this, I removed the item() method and observed that the time cost associated with this operation was completely eliminated.

(profiler trace screenshot)
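To make the difference concrete, here is a minimal sketch of my own (not Megatron-LM code; the token counts are illustrative) contrasting a per-microbatch `item()` call with accumulating on-device and converting once at the end. On a CUDA device, the first variant forces a device-to-host synchronization every iteration, while the second incurs a single synchronization after the loop; both produce the same total.

```python
import torch

# Hypothetical per-microbatch token counts (GPU tensors in the real schedule loop).
num_tokens_per_microbatch = [torch.tensor(128), torch.tensor(96)]

# Variant 1: call .item() each microbatch. Each call copies device -> host
# and blocks the CPU until the GPU has produced that value.
total_blocking = 0
for num_tokens in num_tokens_per_microbatch:
    total_blocking += num_tokens.item()

# Variant 2: accumulate into an on-device scalar tensor (stays asynchronous
# on GPU) and synchronize once at the very end.
total_tensor = torch.zeros((), dtype=torch.long)
for num_tokens in num_tokens_per_microbatch:
    total_tensor += num_tokens
total_deferred = total_tensor.item()  # single device-to-host sync

# Both strategies yield the same count.
print(total_blocking, total_deferred)
```

The trade-off is that variant 2 keeps the running total as a tensor during the loop, so any code that needs the Python integer mid-loop would reintroduce the synchronization.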

Could you clarify why item() is used here?

Thanks for your time and insights!
