@@ -7705,6 +7705,41 @@ def testEmbeddingVariableForSharedEmbeddingColumnsMultiCol(self):
7705
7705
for j in range (3 ):
7706
7706
self .assertAlmostEqual (emb_r [i ][j ], emb_right [i ][j ])
7707
7707
7708
+ def testEmbeddingVariableForSharedPartitionedEmbeddingColumnsMultiCol (self ):
7709
+ columns_list = []
7710
+ columns_list .append (fc .categorical_column_with_embedding ("col_emb" , dtype = dtypes .string ))
7711
+ columns_list .append (fc .categorical_column_with_embedding ("col_emb2" , dtype = dtypes .string ))
7712
+ W = fc .shared_embedding_columns (columns_list ,
7713
+ dimension = 3 ,
7714
+ initializer = init_ops .ones_initializer (dtypes .float32 ),
7715
+ shared_embedding_collection_name = "xxxxx_shared" )
7716
+
7717
+ ids = {}
7718
+ ids ["col_emb" ] = sparse_tensor .SparseTensor (indices = [[0 ,0 ],[1 ,0 ],[2 ,0 ],[3 ,0 ],[4 ,0 ]], values = ["aaaa" ,"bbbbb" ,"ccc" ,"4nn" ,"5b" ], dense_shape = [5 , 5 ])
7719
+ ids ["col_emb2" ] = sparse_tensor .SparseTensor (indices = [[0 ,0 ],[1 ,0 ],[2 ,0 ],[3 ,0 ],[4 ,0 ]], values = ["aaaa" ,"bbbbb" ,"ccc" ,"4nn" ,"5b" ], dense_shape = [5 , 5 ])
7720
+ with variable_scope .variable_scope ("scope" ,partitioner = partitioned_variables .fixed_size_partitioner (4 )):
7721
+ emb = fc_old .input_layer (ids , W )
7722
+ fun = math_ops .multiply (emb , 2.0 , name = 'multiply' )
7723
+ loss = math_ops .reduce_sum (fun , name = 'reduce_sum' )
7724
+ opt = ftrl .FtrlOptimizer (0.1 , l1_regularization_strength = 2.0 , l2_regularization_strength = 0.00001 )
7725
+ g_v = opt .compute_gradients (loss )
7726
+ train_op = opt .apply_gradients (g_v )
7727
+ init = variables_lib .global_variables_initializer ()
7728
+
7729
+ with self .test_session () as sess :
7730
+ sess .run (init )
7731
+ sess .run ([emb , train_op ,loss ])
7732
+ sess .run ([emb , train_op ,loss ])
7733
+ emb_r , _ , _ = sess .run ([emb , train_op ,loss ])
7734
+ emb_right = [[0.7221214 , 0.7221214 , 0.7221214 ],
7735
+ [0.7221214 , 0.7221214 , 0.7221214 ],
7736
+ [0.7221214 , 0.7221214 , 0.7221214 ],
7737
+ [0.7221214 , 0.7221214 , 0.7221214 ],
7738
+ [0.7221214 , 0.7221214 , 0.7221214 ]]
7739
+ for i in range (5 ):
7740
+ for j in range (3 ):
7741
+ self .assertAlmostEqual (emb_r [i ][j ], emb_right [i ][j ])
7742
+
7708
7743
@test_util .run_deprecated_v1
7709
7744
def testEmbeddingVariableForSharedEmbeddingColumnsWithPartitionNum (self ):
7710
7745
columns_list = []
0 commit comments