Skip to content

Commit feab52d

Browse files
authored
[Embedding] Fix SharedEmbeddingColumn with PartitionedEmbeddingVariable shape validation error. (#948)
Signed-off-by: candy.dc <[email protected]>
1 parent 29d9b46 commit feab52d

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

tensorflow/python/feature_column/feature_column.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2675,6 +2675,9 @@ def create_embedding(self,
26752675
embedding_weights = shared_embedding_collection[0]
26762676
if isinstance(embedding_weights, kv_variable_ops.EmbeddingVariable):
26772677
embedding_shape = (self.dimension)
2678+
elif isinstance(embedding_weights, variables.PartitionedVariable):
2679+
if isinstance(embedding_weights._get_variable_list()[0], kv_variable_ops.EmbeddingVariable):
2680+
embedding_shape = (self.dimension)
26782681
if embedding_weights.get_shape() != embedding_shape:
26792682
raise ValueError(
26802683
'Shared embedding collection {} contains variable {} of '

tensorflow/python/feature_column/feature_column_v2_test.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7705,6 +7705,41 @@ def testEmbeddingVariableForSharedEmbeddingColumnsMultiCol(self):
77057705
for j in range(3):
77067706
self.assertAlmostEqual(emb_r[i][j], emb_right[i][j])
77077707

7708+
def testEmbeddingVariableForSharedPartitionedEmbeddingColumnsMultiCol(self):
7709+
columns_list=[]
7710+
columns_list.append(fc.categorical_column_with_embedding("col_emb", dtype=dtypes.string))
7711+
columns_list.append(fc.categorical_column_with_embedding("col_emb2", dtype=dtypes.string))
7712+
W = fc.shared_embedding_columns(columns_list,
7713+
dimension=3,
7714+
initializer=init_ops.ones_initializer(dtypes.float32),
7715+
shared_embedding_collection_name="xxxxx_shared")
7716+
7717+
ids={}
7718+
ids["col_emb"] = sparse_tensor.SparseTensor(indices=[[0,0],[1,0],[2,0],[3,0],[4,0]], values=["aaaa","bbbbb","ccc","4nn","5b"], dense_shape=[5, 5])
7719+
ids["col_emb2"] = sparse_tensor.SparseTensor(indices=[[0,0],[1,0],[2,0],[3,0],[4,0]], values=["aaaa","bbbbb","ccc","4nn","5b"], dense_shape=[5, 5])
7720+
with variable_scope.variable_scope("scope",partitioner=partitioned_variables.fixed_size_partitioner(4)):
7721+
emb = fc_old.input_layer(ids, W)
7722+
fun = math_ops.multiply(emb, 2.0, name='multiply')
7723+
loss = math_ops.reduce_sum(fun, name='reduce_sum')
7724+
opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
7725+
g_v = opt.compute_gradients(loss)
7726+
train_op = opt.apply_gradients(g_v)
7727+
init = variables_lib.global_variables_initializer()
7728+
7729+
with self.test_session() as sess:
7730+
sess.run(init)
7731+
sess.run([emb, train_op,loss])
7732+
sess.run([emb, train_op,loss])
7733+
emb_r, _, _ = sess.run([emb, train_op,loss])
7734+
emb_right = [[0.7221214, 0.7221214, 0.7221214],
7735+
[0.7221214, 0.7221214, 0.7221214],
7736+
[0.7221214, 0.7221214, 0.7221214],
7737+
[0.7221214, 0.7221214, 0.7221214],
7738+
[0.7221214, 0.7221214, 0.7221214]]
7739+
for i in range(5):
7740+
for j in range(3):
7741+
self.assertAlmostEqual(emb_r[i][j], emb_right[i][j])
7742+
77087743
@test_util.run_deprecated_v1
77097744
def testEmbeddingVariableForSharedEmbeddingColumnsWithPartitionNum(self):
77107745
columns_list=[]

tensorflow/python/ops/variables.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3100,6 +3100,9 @@ def __init__(self, name, shape, dtype, variable_list, partitions):
31003100

31013101
self._name = name
31023102
self._shape = shape
3103+
from tensorflow.python.ops import kv_variable_ops
3104+
if isinstance(self._variable_list[0], kv_variable_ops.EmbeddingVariable):
3105+
self._shape = shape[1:]
31033106
self._dtype = dtype
31043107
self._partitions = partitions
31053108
self._as_tensor = None

0 commit comments

Comments
 (0)