Computing cosine similarity between two tensors in Keras
There are a few things that are unclear from the Keras documentation that I think are crucial to understanding:
For each function in the Keras documentation for Merge layers, there is both a lowercase and an uppercase version defined, e.g. add() and Add().
On GitHub, farizrahman4u outlines the differences:
Merge is a layer.
Merge takes layers as input
Merge is usually used with Sequential models
merge is a function.
merge takes tensors as input.
merge is a wrapper around Merge.
merge is used in Functional API
Using Merge:
left = Sequential()
left.add(...)
left.add(...)
right = Sequential()
right.add(...)
right.add(...)
model = Sequential()
model.add(Merge([left, right]))
model.add(...)
Using merge:
a = Input((10,))
b = Dense(10)(a)
c = Dense(10)(a)
d = merge([b, c])
model = Model(a, d)
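The same layer-versus-function split carries over to the specific merge operations such as Add() and add(). Here is a minimal sketch of the two forms side by side, assuming a Keras 2 or later install importable as keras; the input size and variable names are only illustrative.

```python
import numpy as np
import keras
from keras import layers

a = keras.Input(shape=(10,))
b = layers.Dense(10)(a)
c = layers.Dense(10)(a)

# Uppercase name: a layer class. Instantiate it, then call it on a list of tensors.
d_layer = layers.Add()([b, c])

# Lowercase name: a convenience function that takes the list of tensors directly.
d_func = layers.add([b, c])

# Both outputs are b + c, so they should agree for any input.
model = keras.Model(a, [d_layer, d_func])
out_layer, out_func = model.predict(np.ones((2, 10), dtype="float32"), verbose=0)
```

Both forms fit into the functional API; the lowercase one just saves you from instantiating the layer yourself.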
To answer your question: since Merge has been deprecated, we have to define and build a layer for the cosine similarity ourselves. In general this involves using those lowercase functions, which we wrap in a Lambda to create a layer that we can use within a model.
I found a solution here:
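from keras import backend as K
from keras.layers import Lambda

def cosine_distance(vects):
    # vects is a list of two tensors of shape (batch, features).
    x, y = vects
    x = K.l2_normalize(x, axis=-1)
    y = K.l2_normalize(y, axis=-1)
    # The sum (not the mean) of the products of unit vectors is the cosine
    # similarity; it is negated so that more similar pairs give smaller values.
    return -K.sum(x * y, axis=-1, keepdims=True)

def cos_dist_output_shape(shapes):
    # One scalar per sample: (batch_size, 1).
    shape1, shape2 = shapes
    return (shape1[0], 1)

# processed_a and processed_b are the output tensors of the two branches being compared.
distance = Lambda(cosine_distance, output_shape=cos_dist_output_shape)([processed_a, processed_b])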
Depending on your data, you may want to remove the L2 normalization, although the result is then no longer a cosine. What is important to note about the solution is that it is built using Keras backend functions, e.g. K.l2_normalize(). I think this is necessary when defining custom layers or even loss functions, since they have to operate on tensors.
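To see what the normalisation and the reduction do to the numbers, the same arithmetic can be checked outside Keras. This is a small NumPy sketch, not the Keras implementation; the l2_normalize helper and the sample vectors are made up for illustration. After L2 normalisation, summing the element-wise product gives the cosine, while taking the mean gives the cosine divided by the vector length.

```python
import numpy as np

def l2_normalize(v, axis=-1, eps=1e-12):
    # Hypothetical helper: scale each row to unit length (eps guards against zero vectors).
    norm = np.sqrt(np.maximum(np.sum(v * v, axis=axis, keepdims=True), eps))
    return v / norm

# Row 0: parallel vectors (cosine 1). Row 1: orthogonal vectors (cosine 0).
x = np.array([[1.0, 2.0, 3.0], [1.0, 0.0, 0.0]])
y = np.array([[2.0, 4.0, 6.0], [0.0, 1.0, 0.0]])

xn, yn = l2_normalize(x), l2_normalize(y)

cos_sum = np.sum(xn * yn, axis=-1, keepdims=True)    # cosine similarity
cos_mean = np.mean(xn * yn, axis=-1, keepdims=True)  # cosine / number of features
```

For three-dimensional inputs the mean-based value is a third of the cosine, which matters if you compare the output against a fixed threshold or margin.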
Hope I was clear, this was my first SO answer!
The Dot layer in Keras now has built-in support for cosine similarity via the normalize=True parameter.
From the Keras Docs:
keras.layers.Dot(axes, normalize=True)
normalize: Whether to L2-normalize samples along the dot product axis before taking the dot product. If set to True, then the output of the dot product is the cosine proximity between the two samples.
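As a rough sketch of how this could be wired into a functional model, assuming a Keras install importable as keras and two inputs of the same feature size (the size 4 and the sample vectors are only illustrative):

```python
import numpy as np
import keras
from keras import layers

a = keras.Input(shape=(4,))
b = keras.Input(shape=(4,))

# axes=1 takes the dot product over the feature axis of each (batch, features) input;
# normalize=True L2-normalizes both inputs along that axis first.
cos = layers.Dot(axes=1, normalize=True)([a, b])
model = keras.Model([a, b], cos)

# Row 0: parallel vectors. Row 1: orthogonal vectors.
x = np.array([[1, 2, 3, 4], [1, 0, 0, 0]], dtype="float32")
y = np.array([[2, 4, 6, 8], [0, 1, 0, 0]], dtype="float32")

sim = model.predict([x, y], verbose=0)  # shape (2, 1), one similarity per sample
```

In a real model, a and b would typically be replaced by the outputs of the two branches being compared rather than raw inputs.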