@@ -64,8 +64,8 @@ public ANNIndex(final int dimension,
6464 final int blockSize ) throws IOException {
6565 DIMENSION = dimension ;
6666 INDEX_TYPE = indexType ;
67- INDEX_TYPE_OFFSET = ( INDEX_TYPE == IndexType . ANGULAR ) ? 4 : 8 ;
68- K_NODE_HEADER_STYLE = ( INDEX_TYPE == IndexType . ANGULAR ) ? 12 : 16 ;
67+ INDEX_TYPE_OFFSET = INDEX_TYPE . getOffset () ;
68+ K_NODE_HEADER_STYLE = INDEX_TYPE . getkNodeHeaderStyle () ;
6969 // we can store up to MIN_LEAF_SIZE children in leaf nodes (we put
7070 // them where the separating plane normally goes)
7171 this .MIN_LEAF_SIZE = DIMENSION + 2 ;
@@ -151,6 +151,10 @@ private float getNodeBias(final long nodeOffset) { // euclidean-only
151151 return getFloatInAnnBuf (nodeOffset + 4 );
152152 }
153153
154+ private float getDotFactor (final long nodeOffset ) { // dot-only
155+ return getFloatInAnnBuf (nodeOffset + 12 );
156+ }
157+
154158 public final float [] getItemVector (final int itemIndex ) {
155159 return getNodeVector (itemIndex * NODE_SIZE );
156160 }
@@ -175,11 +179,19 @@ private static float euclideanDistance(final float[] u, final float[] v) {
175179 return norm (diff );
176180 }
177181
178- public static float cosineMargin (final float [] u , final float [] v ) {
182+ public static float dot (final float [] u , final float [] v ) {
179183 double d = 0 ;
180184 for (int i = 0 ; i < u .length ; i ++)
181185 d += u [i ] * v [i ];
182- return (float ) (d / (norm (u ) * norm (v )));
186+ return (float ) d ;
187+ }
188+
189+ public static float cosineMargin (final float [] u , final float [] v ) {
190+ return dot (u , v ) / (norm (u ) * norm (v ));
191+ }
192+
193+ public static float dotMargin (final float [] u , final float [] v , final float norm ) {
194+ return dot (u , v ) + norm * norm ;
183195 }
184196
185197 public static float euclideanMargin (final float [] u ,
@@ -271,9 +283,9 @@ public final List<Integer> getNearest(final float[] queryVector,
271283 nearestNeighbors .add (j );
272284 }
273285 } else {
274- float margin = (INDEX_TYPE == IndexType .ANGULAR ) ?
275- cosineMargin ( v , queryVector ) :
276- euclideanMargin (v , queryVector , getNodeBias (topNodeOffset ));
286+ float margin = (INDEX_TYPE == IndexType .ANGULAR ) ? cosineMargin ( v , queryVector )
287+ : ( INDEX_TYPE == IndexType . DOT ) ? dotMargin ( v , queryVector , getDotFactor ( topNodeOffset ))
288+ : euclideanMargin (v , queryVector , getNodeBias (topNodeOffset ));
277289 long childrenMemOffset = topNodeOffset + INDEX_TYPE_OFFSET ;
278290 long lChild = NODE_SIZE * getIntInAnnBuf (childrenMemOffset );
279291 long rChild = NODE_SIZE * getIntInAnnBuf (childrenMemOffset + 4 );
@@ -286,11 +298,10 @@ public final List<Integer> getNearest(final float[] queryVector,
286298 for (int nn : nearestNeighbors ) {
287299 float [] v = getItemVector (nn );
288300 if (!isZeroVec (v )) {
289- sortedNNs .add (
290- new PQEntry ((INDEX_TYPE == IndexType .ANGULAR ) ?
291- cosineMargin (v , queryVector ) :
292- -euclideanDistance (v , queryVector ),
293- nn ));
301+ float margin = (INDEX_TYPE == IndexType .ANGULAR ) ? cosineMargin (v , queryVector )
302+ : (INDEX_TYPE == IndexType .DOT ) ? dot (v , queryVector )
303+ : -euclideanDistance (v , queryVector );
304+ sortedNNs .add (new PQEntry (margin , nn ));
294305 }
295306 }
296307 Collections .sort (sortedNNs );
@@ -317,6 +328,8 @@ public static void main(final String[] args) throws IOException {
317328 IndexType indexType = null ; // 2
318329 if (args [2 ].toLowerCase ().equals ("angular" ))
319330 indexType = IndexType .ANGULAR ;
331+ else if (args [2 ].toLowerCase ().equals ("dot" ))
332+ indexType = IndexType .DOT ;
320333 else if (args [2 ].toLowerCase ().equals ("euclidean" ))
321334 indexType = IndexType .EUCLIDEAN ;
322335 else throw new RuntimeException ("wrong index type specified" );
0 commit comments