diff --git a/dev/test/test_dataloader.c b/dev/test/test_dataloader.c index f400c986d..2803da022 100644 --- a/dev/test/test_dataloader.c +++ b/dev/test/test_dataloader.c @@ -7,12 +7,12 @@ gcc -O3 -I../../llmc -o test_dataloader test_dataloader.c -lm && ./test_dataload TODOs: - test load/save state of DataLoader */ - +#include #include "../../llmc/dataloader.h" #define SHARD_NAME_LEN 64 char shard_name[SHARD_NAME_LEN]; -int num_tokens = 140; +const int num_tokens = 140; int num_shards = 4; void check_range(const int *tokens, const int start, const int end, const char *file, int line) { @@ -181,7 +181,8 @@ void test_shuffled(void) { checkEquals(num_seen_inputs + start + tokens_fit, num_tokens - tokens_fit, 0); // verify the target counts. same thing but offset by 1 checkEquals(num_seen_targets + start + 1, tokens_fit, num_epochs); - checkEquals(num_seen_targets + start + 1 + tokens_fit, num_tokens - tokens_fit, 0); + checkEquals(num_seen_targets + start + 1 + tokens_fit, + (s == (num_shards - 1)) ? num_tokens - tokens_fit - 1 : num_tokens - tokens_fit,0); } dataloader_free(&loader); @@ -204,7 +205,7 @@ void test_multiprocess_shuffled(void) { printf("test_multiprocess_shuffled... "); int B = 4; int T = 8; - int num_processes = 2; + const int num_processes = 2; int should_shuffle = 0; snprintf(shard_name, SHARD_NAME_LEN, "shard_????.bin"); DataLoader loaders[num_processes]; @@ -252,7 +253,8 @@ void test_multiprocess_shuffled(void) { checkEquals(num_seen_inputs + start + tokens_fit, num_tokens - tokens_fit, 0); // verify the target counts. same thing but offset by 1 checkEquals(num_seen_targets + start + 1, tokens_fit, num_epochs); - checkEquals(num_seen_targets + start + 1 + tokens_fit, num_tokens - tokens_fit, 0); + checkEquals(num_seen_targets + start + 1 + tokens_fit, + (s == (num_shards - 1)) ? num_tokens - tokens_fit - 1 : num_tokens - tokens_fit,0); } // cleanup diff --git a/dev/unistd.h b/dev/unistd.h index 337f29ad2..411757f43 100644 --- a/dev/unistd.h +++ b/dev/unistd.h @@ -7,8 +7,11 @@ #include #include -//#define gen_max_length 64 // compile as C++ to skip this VLA issue #include +#include // for malloc and free +#include +#include // for _mkdir and _stat +#include // needed for _access below and _findfirst, _findnext, _findclose #define CLOCK_MONOTONIC 0 static inline int clock_gettime(int ignore_variable, struct timespec* tv) @@ -17,14 +20,12 @@ static inline int clock_gettime(int ignore_variable, struct timespec* tv) } #define OMP /* turn it on */ -#include /* needed for access below */ #define F_OK 0 #define access _access #define TURN_OFF_FP_FAST __pragma(float_control( precise, on, push )) // Save current setting and turn on /fp:precise #define TURN_ON_FP_FAST __pragma(float_control(pop)) // Restore file's default settings -#include /* for _mkdir and _stat */ #define mkdir(path, mode) _mkdir(path) /* sketchy way to get mkdir to work on windows */ #define stat _stat @@ -59,7 +60,7 @@ static inline int glob(const char* pattern, int ignored_flags, int (*ignored_err replace_forward_slashes (pattern_copy); // Replace forward slashes with backslashes - if (strchr(pattern_copy, '\\') != NULL) { + if (strchr(pattern_copy, '\\') != (void*) NULL) { strncpy_s(directory_path, sizeof(directory_path) - 1, pattern_copy, strrchr(pattern_copy, '\\') - pattern_copy + 1); directory_path[strrchr(pattern_copy, '\\') - pattern_copy + 1] = '\0'; }