[QUESTION] In MoE, how is the correctness of Sinkhorn guaranteed when enabling sequence parallelism? #1740

@IamArabbit

In MoE, the router is allowed to use Sinkhorn. However, for a tensor t and its splits t0 and t1 under a sequence-parallel size of 2, Sinkhorn followed by top-k (say k=1) can give different results on the splits than on the full tensor (see the code below).

For example, if we run the code at the bottom, we will get

---- split 0
tensor([1, 3, 2, 2])

---- split 1
tensor([0, 3, 1, 2])

---- full
tensor([1, 3, 2, 0, 0, 3, 1, 2])

This shows that the top-k (k=1) expert indices are not always the same for each token: the fourth token (index 3) is assigned expert 2 in split 0 but expert 0 in the full run. How does Megatron-LM deal with this issue?

import torch

def sinkhorn(cost: torch.Tensor, tol: float = 0.0001) -> torch.Tensor:
    """Sinkhorn based MoE routing function"""
    cost = torch.exp(cost)
    # d0 scales rows (tokens), d1 scales columns (experts).
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)

    eps = 0.00000001
    error = 1e9
    d1_old = d1
    # Alternate row and column scaling until d1 stops changing.
    while error > tol:
        d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
        # The column update sums over every token (row) in the input tensor.
        d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
    return d1 * cost * d0.unsqueeze(1)

def print_sinkhorn_and_topk(cost: torch.Tensor) -> None:
    """Print the top-1 expert index per token after Sinkhorn."""
    res = sinkhorn(cost)
    print(torch.topk(res, k=1, dim=1)[1].squeeze(1))
    print()


torch.manual_seed(42)
sequence_len = 8
num_experts = 4
hidden = torch.rand((sequence_len, num_experts))

hidden_0, hidden_1 = torch.split(hidden, sequence_len//2, dim=0)

print("---- split 0")
print_sinkhorn_and_topk(hidden_0)
print("---- split 1")
print_sinkhorn_and_topk(hidden_1)
print("---- full")
print_sinkhorn_and_topk(hidden)
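
To narrow down where the discrepancy comes from, here is a minimal sketch that compares a purely row-wise router (softmax followed by top-1) with the Sinkhorn router on the same split. It is only an illustration of the effect, not a description of how Megatron-LM handles it. The helper name `compare_split_vs_full` is hypothetical, and `sinkhorn` is repeated from the script above so the block runs on its own.

```python
import torch


def sinkhorn(cost: torch.Tensor, tol: float = 1e-4) -> torch.Tensor:
    """Same iteration as in the reproduction script above."""
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
    eps = 1e-8
    error = 1e9
    d1_old = d1
    while error > tol:
        d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
        # This sum runs over every token in the tensor, so d1 depends on
        # which tokens are present, not on a single token alone.
        d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
    return d1 * cost * d0.unsqueeze(1)


def compare_split_vs_full(logits: torch.Tensor, num_splits: int = 2) -> dict:
    """Hypothetical helper: top-1 expert per token, per split vs. full tensor."""
    chunks = torch.chunk(logits, num_splits, dim=0)
    return {
        # Softmax normalizes each row independently of the other rows.
        "softmax_split": torch.cat(
            [torch.softmax(c, dim=1).argmax(dim=1) for c in chunks]
        ),
        "softmax_full": torch.softmax(logits, dim=1).argmax(dim=1),
        # Sinkhorn also normalizes columns, which couples the rows.
        "sinkhorn_split": torch.cat(
            [sinkhorn(c).argmax(dim=1) for c in chunks]
        ),
        "sinkhorn_full": sinkhorn(logits).argmax(dim=1),
    }


torch.manual_seed(42)
sequence_len, num_experts = 8, 4
logits = torch.rand((sequence_len, num_experts))
result = compare_split_vs_full(logits)

print("softmax agrees: ",
      torch.equal(result["softmax_split"], result["softmax_full"]))
print("sinkhorn agrees:",
      torch.equal(result["sinkhorn_split"], result["sinkhorn_full"]))
mismatch = result["sinkhorn_split"] != result["sinkhorn_full"]
print("tokens where Sinkhorn top-1 differs:",
      torch.nonzero(mismatch).flatten().tolist())
# After the last column update, each expert column sums to about 1/num_experts.
print("column sums (full):", sinkhorn(logits).sum(dim=0))
```

Because the final column update forces every expert column to sum to roughly 1/num_experts over whatever tokens are in the input, each split is balanced against its own tokens only. That is why the per-split assignment can differ from the full-sequence one, while the row-wise softmax router is unaffected by the split.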
