Is sparse tensor multiplication implemented in TensorFlow?
Recently, tf.sparse_tensor_dense_matmul(...) was added, which allows multiplying a sparse matrix by a dense matrix.
https://www.tensorflow.org/versions/r0.9/api_docs/python/sparse_ops.html#sparse_tensor_dense_matmul
https://github.com/tensorflow/tensorflow/issues/1241
General-purpose multiplication for tf.SparseTensor is not currently implemented in TensorFlow. However, there are four partial solutions, and the right one to choose will depend on the characteristics of your data:
If you have a tf.SparseTensor and a tf.Tensor, you can use tf.sparse_tensor_dense_matmul() to multiply them. This is more efficient than the next approach if the sparse tensor would be too large to fit in memory when densified; the documentation has more guidance about how to decide between these two methods. Note that it accepts a tf.SparseTensor as the first argument, so to solve your exact problem (a dense matrix D times a sparse matrix S) you will need to use the adjoint_a and adjoint_b arguments and transpose the result, because D·S = (Sᵀ·Dᵀ)ᵀ.

If you have two sparse tensors and need to multiply them, the simplest (if not the most performant) way is to convert them to dense and use tf.matmul:

    import tensorflow as tf

    a = tf.SparseTensor(...)
    b = tf.SparseTensor(...)
    c = tf.matmul(tf.sparse_tensor_to_dense(a, 0.0), tf.sparse_tensor_to_dense(b, 0.0),
                  a_is_sparse=True, b_is_sparse=True)
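A minimal runnable sketch of the first approach, assuming TensorFlow 2.x with eager execution, where the functions named above are available as tf.sparse.sparse_dense_matmul and tf.sparse.to_dense. The matrices d and s and the helper dense_times_sparse are hypothetical illustrations. The helper computes a dense-by-sparse product with the adjoint flags and a final transpose, and the last statement cross-checks it against the densify-then-tf.matmul approach:

```python
import numpy as np
import tensorflow as tf

# Hypothetical data: a 2x3 dense matrix d and a 3x2 sparse matrix s
# whose dense form is [[0, 10], [0, 0], [20, 0]].
d = tf.constant([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0]])
s = tf.sparse.SparseTensor(indices=[[0, 1], [2, 0]],
                           values=[10.0, 20.0],
                           dense_shape=[3, 2])


def dense_times_sparse(dense, sparse):
    """Hypothetical helper that computes dense @ sparse.

    Following the text, the SparseTensor is passed as the first argument.
    The adjoint flags give S^T @ D^T (for real values the adjoint is just
    the transpose), and transposing that yields (S^T @ D^T)^T == D @ S.
    """
    st_dt = tf.sparse.sparse_dense_matmul(sparse, dense,
                                          adjoint_a=True, adjoint_b=True)
    return tf.transpose(st_dt)


result = dense_times_sparse(d, s)

# Cross-check with the second approach: densify s, then call tf.matmul.
expected = tf.matmul(d, tf.sparse.to_dense(s, default_value=0.0),
                     b_is_sparse=True)
```

For this data, both result and expected should equal [[60, 10], [120, 40]].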
Note that the optional a_is_sparse and b_is_sparse arguments mean that "a (or b) has a dense representation but a large number of its entries are zero", which triggers the use of a different multiplication algorithm.

For the special case of sparse vector by (potentially large and sharded) dense matrix multiplication, where the values in the vector are 0 or 1, the tf.nn.embedding_lookup operator may be more appropriate. This tutorial discusses when you might use embeddings and how to invoke the operator in more detail.

For the special case of sparse matrix by (potentially large and sharded) dense matrix multiplication, tf.nn.embedding_lookup_sparse() may be appropriate. This function accepts one or two tf.SparseTensor objects: sp_ids, whose values are the column indices of the non-zero entries, and the optional sp_weights, whose values are the corresponding non-zero values (which otherwise default to one). Pass combiner="sum" to get the plain matrix product; the default combiner, "mean", divides each output row by the sum of its weights.
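A similar sketch for the two embedding-based approaches, again assuming TensorFlow 2.x with eager execution; w, ids, and the sparse matrix S are made-up data. The 0/1 vector is represented by the positions of its ones, and the sparse matrix by two SparseTensors that share the same indices, with combiner="sum" so that each output row is the weighted sum of the selected rows of w. Both results are compared with an ordinary dense tf.matmul:

```python
import numpy as np
import tensorflow as tf

# Hypothetical dense matrix W: 4 rows (one per id), 2 columns.
w = tf.constant([[1.0, 2.0],
                 [3.0, 4.0],
                 [5.0, 6.0],
                 [7.0, 8.0]])

# Case 1: the 0/1 vector x = [0, 1, 0, 1], stored as the positions of its ones.
# Gathering those rows of W and summing them gives x @ W.
ids = tf.constant([1, 3])
x_times_w = tf.reduce_sum(tf.nn.embedding_lookup(w, ids), axis=0)

# Case 2: the sparse matrix S = [[0, 2, 0, 0],
#                                [1, 0, 0, 3]].
# Each index [row, k] marks the k-th non-zero entry in that row of S;
# sp_ids holds its column in S (a row id into W), sp_weights holds its value.
indices = [[0, 0], [1, 0], [1, 1]]
sp_ids = tf.sparse.SparseTensor(indices=indices,
                                values=tf.constant([1, 0, 3], dtype=tf.int64),
                                dense_shape=[2, 2])
sp_weights = tf.sparse.SparseTensor(indices=indices,
                                    values=[2.0, 1.0, 3.0],
                                    dense_shape=[2, 2])
# combiner="sum" returns, for each row of S, the weighted sum of rows of W,
# which is exactly S @ W.
s_times_w = tf.nn.embedding_lookup_sparse(w, sp_ids, sp_weights,
                                          combiner="sum")

# Dense reference products for comparison.
x_dense = tf.constant([[0.0, 1.0, 0.0, 1.0]])
s_dense = tf.constant([[0.0, 2.0, 0.0, 0.0],
                       [1.0, 0.0, 0.0, 3.0]])
x_reference = tf.matmul(x_dense, w)[0]
s_reference = tf.matmul(s_dense, w)
```

Here x_times_w should be [10, 12] and s_times_w should be [[6, 8], [22, 26]]. In real use, w can instead be a list of shard tensors, which is what makes these ops a good fit for a large, sharded matrix.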