Skip to content
This repository was archived by the owner on Jan 11, 2023. It is now read-only.

Commit 578a506

Browse files
author
Erik Bernhardsson
authored
Merge pull request #17 from yatsukav/master
added dot metric index support
2 parents 0a9d5a3 + 6426a47 commit 578a506

7 files changed

Lines changed: 80 additions & 13 deletions

File tree

src/main/java/com/spotify/annoy/ANNIndex.java

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ public ANNIndex(final int dimension,
6464
final int blockSize) throws IOException {
6565
DIMENSION = dimension;
6666
INDEX_TYPE = indexType;
67-
INDEX_TYPE_OFFSET = (INDEX_TYPE == IndexType.ANGULAR) ? 4 : 8;
68-
K_NODE_HEADER_STYLE = (INDEX_TYPE == IndexType.ANGULAR) ? 12 : 16;
67+
INDEX_TYPE_OFFSET = INDEX_TYPE.getOffset();
68+
K_NODE_HEADER_STYLE = INDEX_TYPE.getkNodeHeaderStyle();
6969
// we can store up to MIN_LEAF_SIZE children in leaf nodes (we put
7070
// them where the separating plane normally goes)
7171
this.MIN_LEAF_SIZE = DIMENSION + 2;
@@ -151,6 +151,10 @@ private float getNodeBias(final long nodeOffset) { // euclidean-only
151151
return getFloatInAnnBuf(nodeOffset + 4);
152152
}
153153

154+
private float getDotFactor(final long nodeOffset) { // dot-only
155+
return getFloatInAnnBuf(nodeOffset + 12);
156+
}
157+
154158
public final float[] getItemVector(final int itemIndex) {
155159
return getNodeVector(itemIndex * NODE_SIZE);
156160
}
@@ -175,11 +179,19 @@ private static float euclideanDistance(final float[] u, final float[] v) {
175179
return norm(diff);
176180
}
177181

178-
public static float cosineMargin(final float[] u, final float[] v) {
182+
public static float dot(final float[] u, final float[] v) {
179183
double d = 0;
180184
for (int i = 0; i < u.length; i++)
181185
d += u[i] * v[i];
182-
return (float) (d / (norm(u) * norm(v)));
186+
return (float) d;
187+
}
188+
189+
public static float cosineMargin(final float[] u, final float[] v) {
190+
return dot(u, v) / (norm(u) * norm(v));
191+
}
192+
193+
public static float dotMargin(final float[] u, final float[] v, final float norm) {
194+
return dot(u, v) + norm * norm;
183195
}
184196

185197
public static float euclideanMargin(final float[] u,
@@ -271,9 +283,9 @@ public final List<Integer> getNearest(final float[] queryVector,
271283
nearestNeighbors.add(j);
272284
}
273285
} else {
274-
float margin = (INDEX_TYPE == IndexType.ANGULAR) ?
275-
cosineMargin(v, queryVector) :
276-
euclideanMargin(v, queryVector, getNodeBias(topNodeOffset));
286+
float margin = (INDEX_TYPE == IndexType.ANGULAR) ? cosineMargin(v, queryVector)
287+
: (INDEX_TYPE == IndexType.DOT) ? dotMargin(v, queryVector, getDotFactor(topNodeOffset))
288+
: euclideanMargin(v, queryVector, getNodeBias(topNodeOffset));
277289
long childrenMemOffset = topNodeOffset + INDEX_TYPE_OFFSET;
278290
long lChild = NODE_SIZE * getIntInAnnBuf(childrenMemOffset);
279291
long rChild = NODE_SIZE * getIntInAnnBuf(childrenMemOffset + 4);
@@ -286,11 +298,10 @@ public final List<Integer> getNearest(final float[] queryVector,
286298
for (int nn : nearestNeighbors) {
287299
float[] v = getItemVector(nn);
288300
if (!isZeroVec(v)) {
289-
sortedNNs.add(
290-
new PQEntry((INDEX_TYPE == IndexType.ANGULAR) ?
291-
cosineMargin(v, queryVector) :
292-
-euclideanDistance(v, queryVector),
293-
nn));
301+
float margin = (INDEX_TYPE == IndexType.ANGULAR) ? cosineMargin(v, queryVector)
302+
: (INDEX_TYPE == IndexType.DOT) ? dot(v, queryVector)
303+
: -euclideanDistance(v, queryVector);
304+
sortedNNs.add(new PQEntry(margin, nn));
294305
}
295306
}
296307
Collections.sort(sortedNNs);
@@ -317,6 +328,8 @@ public static void main(final String[] args) throws IOException {
317328
IndexType indexType = null; // 2
318329
if (args[2].toLowerCase().equals("angular"))
319330
indexType = IndexType.ANGULAR;
331+
else if (args[2].toLowerCase().equals("dot"))
332+
indexType = IndexType.DOT;
320333
else if (args[2].toLowerCase().equals("euclidean"))
321334
indexType = IndexType.EUCLIDEAN;
322335
else throw new RuntimeException("wrong index type specified");
Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,23 @@
11
package com.spotify.annoy;
22

33
public enum IndexType {
4-
ANGULAR, EUCLIDEAN
4+
ANGULAR(4, 12),
5+
EUCLIDEAN(8, 16),
6+
DOT(4, 16);
7+
8+
private final int offset;
9+
private final int kNodeHeaderStyle;
10+
11+
IndexType(int offset, int kNodeHeaderStyle) {
12+
this.offset = offset;
13+
this.kNodeHeaderStyle = kNodeHeaderStyle;
14+
}
15+
16+
public int getOffset() {
17+
return offset;
18+
}
19+
20+
public int getkNodeHeaderStyle() {
21+
return kNodeHeaderStyle;
22+
}
523
}

src/test/java/com/spotify/annoy/ANNIndexTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,16 @@ public void testAngular() throws IOException {
7070
testIndex(IndexType.ANGULAR, 0, false);
7171
}
7272

73+
@Test
74+
/**
75+
Make sure that the NNs retrieved by the Java version match the
76+
ones pre-computed by the C++ version of the Angular index
77+
using the default block size (for files up to 2GB).
78+
*/
79+
public void testDot() throws IOException {
80+
testIndex(IndexType.DOT, 0, false);
81+
}
82+
7383

7484
@Test
7585
/**
@@ -92,6 +102,17 @@ public void testAngularBlocks() throws IOException {
92102
testIndex(IndexType.ANGULAR, 1, false);
93103
}
94104

105+
@Test
106+
/**
107+
Make sure that the NNs retrieved by the Java version match the
108+
ones pre-computed by the C++ version of the Angular index
109+
simulating files larger than 2GB.
110+
*/
111+
public void testDotBlocks() throws IOException {
112+
testIndex(IndexType.DOT, 10, false);
113+
testIndex(IndexType.DOT, 1, false);
114+
}
115+
95116

96117
@Test
97118
/**

src/test/resources/makeindex.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
from annoy import AnnoyIndex
44

55
a = AnnoyIndex(8, 'a')
6+
d = AnnoyIndex(8, 'd')
67
e = AnnoyIndex(8, 'e')
78
for n, l in enumerate(open('points.csv')):
89
x = [float(f) for f in l.strip().split(',')]
910
a.add_item(n, x)
11+
d.add_item(n, x)
1012
e.add_item(n, x)
1113

1214
a.build(-1)
1315
a.save('points.angular.annoy')
16+
d.build(-1)
17+
d.save('points.dot.annoy')
1418
e.build(-1)
1519
e.save('points.euclidean.annoy')
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
1443 1379,161,2218,3294,1343,1507,340,529,140,2100
2+
1240 3610,430,2995,10,2848,200,6,822,3673,2168
3+
818 430,2995,3610,10,200,429,822,3774,1250,3182
4+
1725 627,696,1081,440,1353,2779,2364,6,1088,429
5+
1290 1147,4066,3673,1343,1003,1566,340,3476,1353,2779
6+
2031 1450,1272,440,627,3920,886,3227,1396,1081,4117
7+
1117 430,2995,3610,200,10,822,429,3774,3776,2435
8+
1211 1195,1365,1493,3263,2246,1081,2826,2064,221,627
9+
1902 430,2995,3610,200,10,822,3774,429,1250,3776
10+
603 3294,2100,2218,1854,4066,1343,1379,140,161,3913
410 KB
Binary file not shown.

src/test/resources/retrieve.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ def do(indextype):
1212
print >> out, '%s\t%s' % (q_index, ','.join([str(n) for n in nns]))
1313

1414
do('angular')
15+
do('dot')
1516
do('euclidean')

0 commit comments

Comments
 (0)