Skip to content

Commit

Permalink
Add support for scaling vectors for use with 8bit HNSW codec
Browse files Browse the repository at this point in the history
Add tests for WikiVectors/VectorDictionary
Add 8bit vector task files
  • Loading branch information
Michael Sokolov committed Aug 14, 2022
1 parent aaeca20 commit e33c95b
Show file tree
Hide file tree
Showing 19 changed files with 402 additions and 52 deletions.
78 changes: 74 additions & 4 deletions build.xml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
includes="WikiVectors.java,perf/VectorDictionary.java"
classpathref="build.classpath"
includeantruntime="false"/>

<!-- run a test; we don't have any framework, just call a test method -->
<java className="WikiVectors" classpathref="build.classpath">
<arg value="--test"/>
</java>
</target>

<target name="extract-vector-tasks">
Expand All @@ -37,12 +42,44 @@
<target name="vectors300-docs" depends="build">
<java className="WikiVectors" classpathref="build.classpath">
<arg value="${data.dir}/glove.6B.300d.txt"/>
<arg value="${data.dir}/enwiki-20120502-lines-1k.txt"/>
<arg value="${data.dir}/enwiki-20120502-lines-1k-fixed-utf8-with-random-label.txt"/>
<arg value="${data.dir}/enwiki-20120502-lines-1k-300d.vec"/>
</java>
</target>

<target name="vectors300" depends="vectors300-tasks,vectors300-docs"/>
<target name="vectors300-8bit-docs" depends="build">
<java className="WikiVectors" classpathref="build.classpath">
<!-- scale extrapolated from 100-dim by comparing 1/99%ile -->
<!--
1 Percentile = -0.16812261939048767
10 Percentile = -0.05280676484107971
50 Percentile = 0.0016044741496443748
90 Percentile = 0.052327241748571396
99 Percentile = 0.12062523514032364
100 Percentile = 0.500967800617218
average: -0.0013391887
stddev: 0.05804375
min .. max: -0.6463034 .. 0.5009678
-->
<arg value="-scale"/>
<arg value="256"/>
<arg value="${data.dir}/glove.6B.300d.txt"/>
<arg value="${data.dir}/enwiki-20120502-lines-1k-fixed-utf8-with-random-label.txt"/>
<arg value="${data.dir}/enwiki-20120502-lines-1k-300d-8bit.vec"/>
</java>
</target>

<target name="vectors300-8bit-tasks" depends="build">
<java className="WikiVectors" classpathref="build.classpath">
<arg value="-scale"/>
<arg value="256"/>
<arg value="${data.dir}/glove.6B.300d.txt"/>
<arg value="${tasks.dir}/vector.task.txt"/>
<arg value="${tasks.dir}/vector-task-300d-8bit.vec"/>
</java>
</target>

<target name="vectors300" depends="vectors300-tasks,vectors300-docs,vectors300-8bit-docs,vectors300-8bit-tasks"/>

<target name="vectors100-tasks" depends="build,extract-vector-tasks">
<java className="WikiVectors" classpathref="build.classpath">
Expand All @@ -55,11 +92,44 @@
<target name="vectors100-docs" depends="build">
<java className="WikiVectors" classpathref="build.classpath">
<arg value="${data.dir}/glove.6B.100d.txt"/>
<arg value="${data.dir}/enwiki-20120502-lines-1k.txt"/>
<arg value="${data.dir}/enwiki-20120502-lines-1k-fixed-utf8-with-random-label.txt"/>
<arg value="${data.dir}/enwiki-20120502-lines-1k-100d.vec"/>
</java>
</target>

<target name="vectors100" depends="vectors100-tasks,vectors100-docs"/>
<target name="vectors100-8bit-docs" depends="build">
<java className="WikiVectors" classpathref="build.classpath">
<!-- scale was determined empirically, testing recall -->
<!--
1 Percentile = -0.325579971075058
10 Percentile = -0.09507450535893414
50 Percentile = -0.00024916713300626725
90 Percentile = 0.10140283405780792
99 Percentile = 0.25520724058151245
100 Percentile = 0.5750676989555359
average: -0.0012661816
stddev: 0.10003442
min .. max: -0.7082266 .. 0.5750677
-->
<arg value="-scale"/>
<arg value="128"/>
<arg value="${data.dir}/glove.6B.100d.txt"/>
<arg value="${data.dir}/enwiki-20120502-lines-1k-fixed-utf8-with-random-label.txt"/>
<arg value="${data.dir}/enwiki-20120502-lines-1k-100d-8bit.vec"/>
</java>
</target>

<target name="vectors100-8bit-tasks" depends="build">
<java className="WikiVectors" classpathref="build.classpath">
<arg value="-scale"/>
<arg value="256"/>
<arg value="${data.dir}/glove.6B.100d.txt"/>
<arg value="${tasks.dir}/vector.task.txt"/>
<arg value="${tasks.dir}/vector-task-100d-8bit.vec"/>
</java>
</target>

<!-- don't scale the task data here; we'll do it later after summing vectors for each term -->
<target name="vectors100" depends="vectors100-tasks,vectors100-docs,vectors100-8bit-docs,vectors100-8bit-tasks"/>

</project>
4 changes: 4 additions & 0 deletions resources/test-dict.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
publisher -0.056504 -0.66872 -0.33651 -0.10154 1.2664 0.32571 0.24443 -0.90948 0.096117 -0.32758 0.46793 -0.56473 0.040005 -0.86401 -0.82027 -0.0025342 0.6563 -0.41941 -0.28702 0.3562 0.02351 0.27412 -0.3095 0.23099 0.79746 -0.18329 0.40602 -0.52868 -0.16673 0.54727 0.36206 0.98564 -0.52264 -0.50094 -0.61851 0.33877 -0.12586 0.0095477 0.07713 -0.3797 -0.37611 0.71953 -0.56789 -0.19281 -0.53684 -0.093614 -0.34536 -0.6573 0.36934 0.26466 0.59997 0.18451 0.52943 -0.39581 0.33493 -1.4055 -0.57196 0.34593 1.1552 0.021271 0.9859 -0.1074 -0.62339 -0.30346 1.0431 -0.93982 1.0215 1.176 0.60246 0.21486 -0.041185 0.85326 -0.34287 -0.047272 -0.77717 -0.29112 0.084178 -0.10617 -1.0464 -0.17233 0.1773 -0.21221 -0.35844 0.43236 0.26568 -1.0374 -0.13321 -0.57539 0.18294 -0.61237 0.32514 0.2414 1.1302 0.03843 -0.51238 0.42528 0.48861 -0.53346 0.39545 0.16064
publishers 0.28918 -0.55375 -0.41782 0.18842 0.9574 0.16435 -0.24372 -0.55326 0.3221 -0.83166 0.46675 -1.0865 -0.031945 -0.96278 -1.2301 0.3369 0.2618 0.077774 -0.2947 0.30948 -0.10792 0.092807 -0.060774 -0.072964 0.80845 -0.15378 0.67276 -0.29604 -0.11491 -0.15566 0.70413 0.5515 -0.58488 -0.86914 0.11524 0.49648 -0.20281 0.57394 -0.27126 -0.38978 -0.37894 0.65863 -1.5056 -0.10715 -0.38215 -0.83031 0.15389 -0.38095 0.80655 0.61584 0.47668 0.46514 -0.39508 -0.19577 0.07787 -0.59569 -0.14372 -0.099975 0.98886 0.19469 0.45553 0.37968 -0.55694 0.18355 0.731 -0.32917 0.41706 1.0847 0.4782 -0.63639 -0.069924 0.69925 -0.14858 -0.47013 -1.049 -0.13939 -0.12439 0.047415 -1.1494 -0.6763 0.2363 -0.42335 -0.66608 0.13898 -1.0781 -0.65048 -0.25339 -0.72188 -0.28868 0.19169 -0.030016 0.38162 0.29434 -1.0626 -0.80629 0.22669 0.15923 0.12407 1.0077 0.36619
backstory 0.17688 0.26544 0.48175 -0.037923 0.0225 0.26334 -0.18425 -0.11263 0.34749 0.34949 -0.096689 0.36261 0.13968 -0.52273 0.52208 0.35764 0.40101 0.40433 0.45339 0.4524 -0.12238 -0.39222 -0.45615 0.030286 -0.059574 -0.18065 0.38055 0.20366 -0.58678 0.3039 -0.21437 -0.37445 -0.16495 -0.36146 -0.27811 0.1546 -0.18351 0.29061 -0.3622 0.24528 0.43101 0.4811 0.14832 0.34707 -0.72842 0.50004 -0.1608 -0.044481 0.87151 0.30783 -0.20188 -0.14713 0.74831 0.75364 0.24787 -0.16816 0.076494 -0.46153 -0.94948 0.17161 -0.32485 0.6395 0.23709 0.19177 0.92626 -0.79694 -0.40348 -0.10001 0.093746 0.19737 -0.15884 0.12526 0.10808 -0.26331 0.4277 0.50344 0.14049 -0.085907 -0.27027 0.33762 -0.85896 0.30613 -0.069419 0.1909 -0.53789 0.41812 -0.32304 0.11027 -0.18164 0.24819 0.02426 0.50413 0.14018 -0.21246 1.0737 -0.58549 0.10076 -0.52355 -0.41074 0.89364
many -0.32914 0.82887 -0.14182 -0.27705 0.010944 0.42952 -0.56005 -0.07194 0.080524 -0.40554 0.043851 -0.31766 0.52202 -0.16149 0.043372 -0.30606 0.035574 0.10558 -0.13047 0.67779 0.45329 0.0075139 0.30743 -0.25804 0.0085955 -0.93448 0.00061153 -0.58644 0.06784 -0.019375 0.33947 0.30926 -0.39635 -0.094199 0.01055 0.52399 0.084729 0.28158 -0.33752 -0.19876 -1.1249 -0.19234 -0.012407 -0.19436 0.10601 -0.18132 0.67892 -0.14356 -0.0063351 -0.10511 -0.15675 -0.28684 -0.078341 0.82427 -0.021684 -2.1601 0.30517 -0.33368 1.7488 0.70295 -0.38371 1.611 0.43974 0.30393 0.89573 -0.093649 0.72926 0.10061 0.72969 -0.46212 -0.42695 0.020632 -0.21447 -0.19951 0.090397 0.13428 0.1011 0.13732 -1.0366 0.01218 1.309 -0.21578 -0.3893 0.22264 -2.1672 -0.048762 -0.4835 -0.35004 -0.69585 -0.25604 0.035456 -0.29506 -0.30143 -0.16867 -1.4708 -0.21042 -1.0739 -0.057574 0.62466 0.59499
2 changes: 2 additions & 0 deletions resources/test-tasks.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
publisher backstory
many geografia
190 changes: 178 additions & 12 deletions src/main/WikiVectors.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.Reader;
Expand All @@ -29,6 +30,9 @@
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import perf.VectorDictionary;

Expand All @@ -38,39 +42,74 @@
* as a precursor for indexing vectors in benchmark runs. It's provided for "offline" (manual) use,
* and doesn't factor into benchmark execution.
*/
public class WikiVectors {
public class WikiVectors<T> {

private final VectorDictionary dict;
private final VectorDictionary<?> dict;

float scale;
int dimension;

public static void main(String[] args) throws Exception {
if (args.length != 3) {
System.err.println("usage: WikiVectors <vectorDictionary> <lineDocs> <docVectorOutput>");
if (args.length < 1) {
usage();
}
float scale = 0;
List<String> argList = List.of(args);
if (argList.get(0).equals("--test")) {
test();
return;
}
if (args.length < 3) {
usage();
}
if (args[0].equals("-scale")) {
scale = Float.parseFloat(args[1]);
argList = argList.subList(2, argList.size());
}
if (argList.size() != 3) {
usage();
}
WikiVectors wikiVectors = new WikiVectors(argList.get(0), scale);
try (OutputStream out = Files.newOutputStream(Paths.get(argList.get(2)))) {
wikiVectors.computeVectors(argList.get(1), out);
}
}

static void usage() {
System.err.println("usage: WikiVectors --test | [-scale X] <vectorDictionary> <lineDocs> <docVectorOutput>");
System.exit(-1);
}

WikiVectors(String dictFileName, float scale) throws IOException {
this.scale = scale;
if (scale == 0) {
dict = VectorDictionary.create(dictFileName);
} else {
dict = VectorDictionary.create(dictFileName, scale);
}
WikiVectors wv = new WikiVectors(new VectorDictionary(args[0]));
wv.computeVectors(args[1], args[2]);
}

WikiVectors(VectorDictionary dict) {
this.dict = dict;
void computeVectors(String lineDocFile, OutputStream out) throws IOException {
if (scale == 0) {
computeFloatVectors(lineDocFile, out);
} else {
computeByteVectors(lineDocFile, out);
}
}

void computeVectors(String lineDocFile, String outputFile) throws IOException {
void computeFloatVectors(String lineDocFile, OutputStream out) throws IOException {
int count = 0;
CharsetDecoder dec=StandardCharsets.UTF_8.newDecoder()
.onMalformedInput(CodingErrorAction.REPLACE); // replace invalid input with the UTF8 replacement character
try (OutputStream out = Files.newOutputStream(Paths.get(outputFile));
Reader r = Channels.newReader(FileChannel.open(Paths.get(lineDocFile)), dec, -1);
try (Reader r = Channels.newReader(FileChannel.open(Paths.get(lineDocFile)), dec, -1);
BufferedReader in = new BufferedReader(r)) {
String lineDoc;
byte[] buffer = new byte[dict.dimension * Float.BYTES];
ByteBuffer bbuf = ByteBuffer.wrap(buffer)
.order(ByteOrder.LITTLE_ENDIAN);
FloatBuffer fbuf = bbuf.asFloatBuffer();
while ((lineDoc = in.readLine()) != null) {
float[] dvec = dict.computeTextVector(lineDoc);
float[] dvec = (float[]) dict.computeTextVector(lineDoc);
fbuf.position(0);
fbuf.put(dvec);
out.write(buffer);
Expand All @@ -85,4 +124,131 @@ void computeVectors(String lineDocFile, String outputFile) throws IOException {
}
}

void computeByteVectors(String lineDocFile, OutputStream out) throws IOException {
int count = 0;
CharsetDecoder dec=StandardCharsets.UTF_8.newDecoder()
.onMalformedInput(CodingErrorAction.REPLACE); // replace invalid input with the UTF8 replacement character
try (Reader r = Channels.newReader(FileChannel.open(Paths.get(lineDocFile)), dec, -1);
BufferedReader in = new BufferedReader(r)) {
String lineDoc;
while ((lineDoc = in.readLine()) != null) {
byte[] vec = (byte[]) dict.computeTextVector(lineDoc);
out.write(vec);
if (++count % 10000 == 0) {
System.out.print("wrote " + count + "\n");
}
}
System.out.println("wrote " + count);
} catch (IOException e) {
System.err.println("An error occurred on line " + (count + 1));
throw e;
}
}

//-------------------------------------------------------------------------------//
// //
// Test Methods //
// //
//-------------------------------------------------------------------------------//

static void test() throws IOException {
testUnscaled();
System.out.println("testUnscaled: ok");
testScaled();
System.out.println("testScaled: ok");
}

static void testUnscaled() throws IOException {
WikiVectors wikiVectors = new WikiVectors("resources/test-dict.txt", 0);
assertEquals(100, wikiVectors.dict.dimension);
assertEquals(100, wikiVectors.dict.get("many").length);
assertEquals(0f, wikiVectors.dict.scale);
assertEquals(4, wikiVectors.dict.size());
// vectors were normalized
assertClose(1f, (float) VectorDictionary.vectorNorm(wikiVectors.dict.get("publisher")));
assertClose(1f, (float) VectorDictionary.vectorNorm(wikiVectors.dict.get("backstory")));
assertClose(1f, (float) VectorDictionary.vectorNorm(wikiVectors.dict.get("many")));
// compare ratios since these are invariant under scaling, and we normalized the input
assertClose(-0.056504f / 0.16064f, wikiVectors.dict.get("publisher")[0] / wikiVectors.dict.get("publisher")[99]);
assertClose(-0.32914f / 0.59499f, wikiVectors.dict.get("many")[0] / wikiVectors.dict.get("many")[99]);
assertThat(wikiVectors.dict.get("geografia") == null);
try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
wikiVectors.computeVectors("resources/test-tasks.txt", out);
byte[] buf = out.toByteArray();
FloatBuffer floats = ByteBuffer.wrap(buf).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
float[] vec = new float[100];
// vector for "publisher backstory"
floats.get(vec);
assertClose(vec[0] / vec[99],
(wikiVectors.dict.get("publisher")[0] + wikiVectors.dict.get("backstory")[0])
/
(wikiVectors.dict.get("publisher")[99] + wikiVectors.dict.get("backstory")[99]));
// vector for "many geografia" - geografia is not there
floats.get(vec);
assertClose(vec[0], wikiVectors.dict.get("many")[0]);
assertClose(vec[99], wikiVectors.dict.get("many")[99]);
}
}

static void testScaled() throws IOException {
float scale = 128f;
WikiVectors wikiVectors = new WikiVectors("resources/test-dict.txt", scale);
assertEquals(100, wikiVectors.dict.dimension);
assertEquals(100, wikiVectors.dict.get("many").length);
assertEquals(scale, wikiVectors.dict.scale);
assertEquals(4, wikiVectors.dict.size());
// vectors were normalized
assertClose(1f, (float) VectorDictionary.vectorNorm(wikiVectors.dict.get("publisher")));
assertClose(1f, (float) VectorDictionary.vectorNorm(wikiVectors.dict.get("backstory")));
assertClose(1f, (float) VectorDictionary.vectorNorm(wikiVectors.dict.get("many")));
// compare ratios since these are invariant under scaling, and we normalized the input
assertClose(-0.056504f / 0.16064f, wikiVectors.dict.get("publisher")[0] / wikiVectors.dict.get("publisher")[99]);
assertClose(-0.32914f / 0.59499f, wikiVectors.dict.get("many")[0] / wikiVectors.dict.get("many")[99]);
assertThat(wikiVectors.dict.get("geografia") == null);
try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
wikiVectors.computeVectors("resources/test-tasks.txt", out);
byte[] buf = out.toByteArray();
// we wrote two 100-dimensional vectors
assertEquals(200, buf.length);
// vector for "publisher backstory"
assertClose(buf[0] / (float) buf[99],
(wikiVectors.dict.get("publisher")[0] + wikiVectors.dict.get("backstory")[0])
/
(float) (wikiVectors.dict.get("publisher")[99] + wikiVectors.dict.get("backstory")[99]),
1/128f);
// vector for "many geografia" - geografia is not there
assertEquals(buf[100], scaleToByte(wikiVectors.dict.get("many")[0], scale));
assertEquals(buf[199], scaleToByte(wikiVectors.dict.get("many")[99], scale));
}
}

private static byte scaleToByte(float f, float scale) {
return (byte) Math.min(Math.max(f * scale, -128), 127);
}

private static void assertClose(float a, float b) {
assertClose(a, b, 1e-5f);
}

private static void assertClose(float a, float b, float tolerance) {
if (Math.abs(a - b) > tolerance) {
fail(a + " is not close to " + b);
}
}

private static void assertEquals(Object a, Object b) {
if (!a.equals(b)) {
fail(a + " is not equal to " + b);
}
}

private static void assertThat(boolean condition) {
if (!condition) {
fail("condition was not true");
}
}

private static void fail(String message) {
throw new AssertionError(message);
}
}
8 changes: 8 additions & 0 deletions src/main/perf/Args.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ public float getFloat(String argName) {
return Float.parseFloat(getString(argName));
}

public Float getFloat(String argName, Float defaultValue) {
String arg = getString(argName, null);
if (arg == null) {
return defaultValue;
}
return Float.parseFloat(arg);
}

public long getLong(String argName) {
return Long.parseLong(getString(argName));
}
Expand Down
Loading

0 comments on commit e33c95b

Please sign in to comment.