Labels: question (Further information is requested)
Description
In MoE, Sinkhorn can be used in the router. However, for a tensor t and its splits t0, t1 under sequence parallelism of size 2, the topk result (let's say k=1) after Sinkhorn on the splits can differ from the result on the full tensor (see the code at the bottom).
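As far as I can tell from the sinkhorn function at the bottom, the difference comes from the column normalization. This is a sketch of my reading of that code, not something taken from the Megatron-LM documentation. With n tokens in the batch, E experts, and C_ij = exp(cost_ij), each iteration computes:

$$
d^{(0)}_i \leftarrow \frac{1}{n\left(\sum_{j=1}^{E} d^{(1)}_j C_{ij} + \epsilon\right)}, \qquad
d^{(1)}_j \leftarrow \frac{1}{E\left(\sum_{i=1}^{n} d^{(0)}_i C_{ij} + \epsilon\right)}, \qquad
P_{ij} = d^{(0)}_i \, C_{ij} \, d^{(1)}_j
$$

Since d^{(0)}_i is a positive factor shared by the whole row i, the top-1 expert of token i is argmax_j C_ij d^{(1)}_j. The per-expert factor d^{(1)}_j sums over every token in the batch, so it is computed from 4 tokens on each split but from 8 tokens on the full tensor, and the argmax can change.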
For example, if we run the code at the bottom, we will get
---- split 0
tensor([1, 3, 2, 2])
---- split 1
tensor([0, 3, 1, 2])
---- full
tensor([1, 3, 2, 0, 0, 3, 1, 2])
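This shows that the topk (k=1) indices are not always the same for each token: the last token of split 0 is routed to expert 2 when Sinkhorn runs on the split, but to expert 0 when it runs on the full tensor. How does Megatron-LM deal with this issue?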
import torch


def sinkhorn(cost: torch.Tensor, tol: float = 0.0001):
    """Sinkhorn based MoE routing function"""
    cost = torch.exp(cost)
    # d0 scales rows (tokens), d1 scales columns (experts).
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)

    eps = 0.00000001
    error = 1e9
    d1_old = d1
    # Alternate row and column normalization until d1 stops changing.
    while error > tol:
        d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
        d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
    return d1 * cost * d0.unsqueeze(1)


def print_sinkhorn_and_topk(cost: torch.Tensor) -> None:
    # Print the top-1 expert index of every token after Sinkhorn.
    res = sinkhorn(cost)
    print(torch.topk(res, k=1, dim=1)[1].squeeze(1))
    print()


torch.manual_seed(42)
sequence_len = 8
num_experts = 4
hidden = torch.rand((sequence_len, num_experts))
# Simulate sequence parallelism of size 2 by splitting along the token dimension.
hidden_0, hidden_1 = torch.split(hidden, sequence_len//2, dim=0)

print("---- split 0")
print_sinkhorn_and_topk(hidden_0)
print("---- split 1")
print_sinkhorn_and_topk(hidden_1)
print("---- full")
print_sinkhorn_and_topk(hidden)
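One workaround I can think of is to run Sinkhorn on the logits of all tokens and let each rank keep only the indices for its own tokens. This is purely an assumption on my side, and I do not know whether Megatron-LM does anything like it. The sketch below simulates the idea in a single process: `route_full_then_slice` and `world_size` are hypothetical names, the iteration cap `max_iters` is my addition, and in a real run the full logits would have to be collected across the sequence-parallel ranks first (for example with `torch.distributed.all_gather`).

```python
import torch


def sinkhorn(cost: torch.Tensor, tol: float = 1e-4, max_iters: int = 1000) -> torch.Tensor:
    """Same iteration as above, plus an iteration cap (my addition)."""
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0), dtype=cost.dtype)
    d1 = torch.ones(cost.size(1), dtype=cost.dtype)
    eps = 1e-8
    d1_old = d1
    for _ in range(max_iters):
        d0 = (1 / d0.size(0)) / (torch.sum(d1 * cost, 1) + eps)
        d1 = (1 / d1.size(0)) / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
        if torch.mean(torch.abs(d1_old - d1)) <= tol:
            break
        d1_old = d1
    return d1 * cost * d0.unsqueeze(1)


def route_full_then_slice(logits: torch.Tensor, world_size: int) -> list:
    """Hypothetical helper: Sinkhorn over all tokens, then per-rank top-1 indices."""
    balanced = sinkhorn(logits)
    top1 = torch.argmax(balanced, dim=1)
    # Each simulated rank keeps only the indices of its own token slice.
    return list(torch.chunk(top1, world_size, dim=0))


torch.manual_seed(42)
logits = torch.rand(8, 4)  # 8 tokens, 4 experts, as in the example above
per_rank = route_full_then_slice(logits, world_size=2)
full = torch.argmax(sinkhorn(logits), dim=1)

for rank, idx in enumerate(per_rank):
    print(f"---- rank {rank}")
    print(idx)
# By construction, the per-rank slices agree with the full-sequence routing.
print(torch.equal(torch.cat(per_rank), full))
```

The trade-off is that every rank needs the router logits of all tokens, which adds communication, and the routing still depends on how many tokens are balanced together.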