@@ -33,16 +33,21 @@ void VLogGraphDebugString(Graph* g) {
33
33
34
34
class EmbeddingForwardBackwardJointOptimizationPass : public GraphOptimizationPass {
35
35
public:
36
+ EmbeddingForwardBackwardJointOptimizationPass () : GraphOptimizationPass() {
37
+ tensorflow::ReadBoolFromEnvVar (" TF_EMBEDDING_FBJ_OPT" ,
38
+ /* default_val=*/ false , &embedding_fbj_opt_);
39
+ if (!embedding_fbj_opt_) {
40
+ VLOG (2 ) << " Graph Optimization Pass TF_EMBEDDING_FBJ_OPT is off." ;
41
+ } else {
42
+ VLOG (2 ) << " Graph Optimization Pass TF_EMBEDDING_FBJ_OPT is on." ;
43
+ }
44
+ }
45
+
36
46
Status Run (const GraphOptimizationPassOptions& options) override {
37
- bool embedding_fbj_opt = false ;
38
- TF_CHECK_OK (
39
- tensorflow::ReadBoolFromEnvVar (" TF_EMBEDDING_FBJ_OPT" ,
40
- /* default_val=*/ false , &embedding_fbj_opt));
41
- if (!embedding_fbj_opt) {
42
- LOG (INFO) << " TF_EMBEDDING_FBJ_OPT off." ;
47
+ if (!embedding_fbj_opt_) {
43
48
return Status::OK ();
44
49
}
45
- LOG (INFO) << " TF_EMBEDDING_FBJ_OPT on. " ;
50
+
46
51
if (options.graph == nullptr ) {
47
52
// TODO(apassos) returning OK feels weird here as we can't do anything
48
53
// without a graph, but some tests require this.
@@ -198,7 +203,10 @@ class EmbeddingForwardBackwardJointOptimizationPass : public GraphOptimizationPa
198
203
return Status::OK ();
199
204
}
200
205
206
+ private:
207
+ bool embedding_fbj_opt_ = false ;
201
208
};
209
+
202
210
REGISTER_OPTIMIZATION (OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 0 ,
203
211
EmbeddingForwardBackwardJointOptimizationPass);
204
212
0 commit comments