Description of feature
As mentioned in the discussion around PR #110, there are more improvements to be made to the explainer functions. Some points that came up:
Multiple attributions for different classes from one gradient pass
The biggest one is getting attributions for multiple classes with one gradient pass, like the preliminary implementation we had in #104. I'm not 100% sure it'd work, but since it looked okay before, it seems like it should.
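As a rough sketch of one way this could look in TensorFlow (hypothetical helper, not the #104 implementation; assumes one-hot (seq_len, 4) inputs): tile each sequence once per requested class and take the gradient of the sum of the selected logits, so a single backward pass yields per-class gradients.

```python
import tensorflow as tf

def multi_class_gradients(model, seqs, class_idx):
    """seqs: (batch, seq_len, 4) one-hot sequences; class_idx: list of output classes.

    Returns gradients of shape (batch, n_classes, seq_len, 4) from one backward pass.
    """
    n_classes = len(class_idx)
    batch = tf.shape(seqs)[0]
    # Repeat every sequence once per requested class: (batch * n_classes, seq_len, 4)
    tiled = tf.repeat(tf.cast(seqs, tf.float32), repeats=n_classes, axis=0)
    # Matching class index for every tiled copy
    classes = tf.tile(tf.constant(class_idx, dtype=tf.int32), [batch])

    with tf.GradientTape() as tape:
        tape.watch(tiled)
        preds = model(tiled)  # (batch * n_classes, n_outputs)
        # Each tiled copy only contributes its own class logit to the sum, so the
        # gradient w.r.t. each copy is the gradient of that class alone.
        selected = tf.reduce_sum(preds * tf.one_hot(classes, tf.shape(preds)[-1]), axis=-1)
        total = tf.reduce_sum(selected)

    grads = tape.gradient(total, tiled)  # (batch * n_classes, seq_len, 4)
    return tf.reshape(grads, [batch, n_classes, -1, 4])
```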
Better transfers to/from GPU
There might also be improvements to be had by doing more on the GPU:
- Transferring only the entire output to cpu/numpy in function_batch, rather than each batch. That does mean you can't just write batches to a numpy array anymore, but need to write to e.g. a tf.Variable tensor or a list and concat with tf./torch.concat.
- If that's implemented, integrated_grad() could also do the integration over steps and the averaging over baselines on the GPU, so that the object to transfer to CPU shrinks from (n_baselines, n_steps, seq_len, 4) to (1, seq_len, 4) (usually 25 baselines and 26 steps, so ~650x smaller).
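A rough sketch of what that could look like in TensorFlow (hypothetical helper names, not the current function_batch / integrated_grad() signatures): batch outputs are collected in a list on the GPU, concatenated once, and the step/baseline reduction happens before the single transfer to numpy.

```python
import tensorflow as tf

def function_batch_on_gpu(f, inputs, batch_size=128):
    """Apply f to inputs in batches, keeping every batch output on the GPU."""
    outputs = []
    for start in range(0, inputs.shape[0], batch_size):
        outputs.append(f(inputs[start:start + batch_size]))  # device tensor, no .numpy()
    return tf.concat(outputs, axis=0)  # single concatenated tensor, still on the GPU

def reduce_then_transfer(grads, n_baselines, n_steps):
    """grads: (n_baselines * n_steps, seq_len, 4) gradients, still on the GPU.

    Average over baselines and integration steps on the device (the
    (input - baseline) scaling of full integrated gradients is left out here),
    so only a (seq_len, 4) array crosses to the CPU instead of ~650 full-size ones.
    """
    grads = tf.reshape(grads, (n_baselines, n_steps, -1, 4))
    return tf.reduce_mean(grads, axis=(0, 1)).numpy()
```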
Batching over sequences
When calculating saliency maps (aka raw gradients), we batch over all sequences (since explaining one sequence only requires one gradient, that's easy), but for integrated_gradients() we currently don't. Instead, there we loop over sequences and call function_batch() separately for each sequence.
If you're calculating expected integrated gradients (num_baselines=25 and baseline_type='random'), this wouldn't make a difference, since each sequence already requires 25*26=650 explained sequences, so every batch is full anyway. However, if you're doing plain integrated gradients (num_baselines=1 and baseline_type='zeros'), batching over sequences could give a speedup: you normally only get gradients for 26 sequences to explain one sequence, so you could fit many more into the default batch size of 128.
The problem with that is that you use much, much more memory, especially when doing expected integrated gradients. As an example, for 500 sequences we currently never exceed 499+650=1149 gradients in memory (499 'reduced' final explanations + 650 temporary gradients being reduced into the final 500th explanation). If we batch over all sequences, we'd have 500*650=325000 gradients in memory at the same time, which I'm not sure is desirable. To solve that, you could implement a CPU-side batch size or something (roughly as sketched below), but it does get complicated fast.
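To make that concrete, here's a rough sketch (hypothetical names, not the current API) of batching plain integrated gradients (zeros baseline, num_baselines=1) over sequences while capping memory with an outer per-chunk loop, which is roughly the "CPU-side batch size" idea:

```python
import numpy as np
import tensorflow as tf

def integrated_grad_batched(explain_fn, seqs, n_steps=26, chunk_size=32):
    """seqs: (n_seqs, seq_len, 4) one-hot numpy array of sequences.

    explain_fn maps a (m, seq_len, 4) batch of interpolated sequences to their
    gradients (same shape). chunk_size bounds peak memory: at most
    chunk_size * n_steps gradients are alive at a time.
    """
    results = []
    for start in range(0, seqs.shape[0], chunk_size):
        chunk = tf.cast(seqs[start:start + chunk_size], tf.float32)   # (c, L, 4)
        alphas = tf.linspace(0.0, 1.0, n_steps)                       # (n_steps,)
        # Interpolate from the zeros baseline: (c, n_steps, L, 4), then flatten for the model
        interp = alphas[None, :, None, None] * chunk[:, None, :, :]
        flat = tf.reshape(interp, (-1, *chunk.shape[1:]))             # (c * n_steps, L, 4)
        grads = explain_fn(flat)                                      # (c * n_steps, L, 4)
        grads = tf.reshape(grads, (chunk.shape[0], n_steps, *chunk.shape[1:]))
        # Integrate over steps on the GPU, then scale by (input - baseline) = input
        attribution = tf.reduce_mean(grads, axis=1) * chunk
        results.append(attribution.numpy())                           # small CPU transfer
    return np.concatenate(results, axis=0)                            # (n_seqs, L, 4)
```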
Mutagenesis
I haven't touched this code at all; it could probably be cleaned up as well.
General code improvements
Also, there might just be bottlenecks I'm not seeing (I'm not a code speed expert by any means). I haven't looked at the torch side at all specifically, and it might do transfers to/from the GPU differently than TensorFlow. I also got rid of the original dataset-based batching function because it kept giving warnings in TensorFlow, but it might be a perfectly fine way to work in PyTorch? I haven't tested that, but it's worth checking.