Skip to content

Commit

Permalink
Refine load_records
Browse files Browse the repository at this point in the history
  • Loading branch information
primenumber committed Jun 23, 2024
1 parent 945a4e6 commit dd3ae08
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 51 deletions.
48 changes: 11 additions & 37 deletions src/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::engine::hand::*;
use anyhow::Result;
use std::fmt::*;
use std::fs::File;
use std::io::{BufRead, BufReader, Read};
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::str::FromStr;
use thiserror::Error;
Expand Down Expand Up @@ -46,11 +46,7 @@ impl Record {
let mut board = self.initial_board;
let mut res = Vec::new();
let final_score = self.final_score.ok_or(ScoreIsNotRegistered {})?;
let mut score = if self.hands.len() % 2 == 0 {
final_score
} else {
-final_score
};
let mut score = final_score;
for &h in &self.hands {
res.push((board, h, score));
board = board.play_hand(h).ok_or(UnmovableError {})?;
Expand Down Expand Up @@ -104,44 +100,22 @@ impl FromStr for Record {
let score = if let Some(score) = splitted.get(1) {
score.parse().ok()
} else if board.is_gameover() {
Some(board.score() as i16)
let absolute_score = if l % 2 == 0 {
board.score()
} else {
-board.score()
};
Some(absolute_score as i16)
} else {
None
};
Ok(Record::new(Board::initial_state(), &hands, score))
}
}

pub struct LoadRecords<R: Read> {
reader: BufReader<R>,
buffer: String,
remain: usize,
}

impl<R: Read> Iterator for LoadRecords<R> {
type Item = Result<Record>;
fn next(&mut self) -> Option<Self::Item> {
if self.remain > 0 {
self.remain -= 1;
self.reader.read_line(&mut self.buffer).ok()?;
return Some(self.buffer.parse::<Record>().map_err(|e| e.into()));
}
None
}
}

pub fn load_records(path: &Path) -> Result<LoadRecords<File>> {
pub fn load_records(path: &Path) -> Result<impl Iterator<Item = Result<Record, ParseRecordError>>> {
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let mut buffer = String::new();

reader.read_line(&mut buffer)?;
let remain = buffer.trim().parse()?;
buffer.clear();
let reader = BufReader::new(file);

Ok(LoadRecords {
reader,
buffer,
remain,
})
Ok(reader.lines().map(|line| line.unwrap().parse::<Record>()))
}
40 changes: 26 additions & 14 deletions src/train.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;
use std::str;
use std::sync::Arc;
use rand::prelude::*;

pub fn clean_record(matches: &ArgMatches) {
let input_path = matches.get_one::<String>("INPUT").unwrap();
let output_path = matches.get_one::<String>("OUTPUT").unwrap();

let mut result = Vec::new();
for record in load_records(Path::new(input_path)).unwrap() {
if let Ok(record) = record {
if let Ok(_timeline) = record.timeline() {
result.push(record);
}
let Ok(record) = record else { continue; };
if let Ok(_timeline) = record.timeline() {
result.push(record);
}
}

Expand All @@ -47,18 +47,20 @@ pub fn gen_dataset(matches: &ArgMatches) {
.unwrap()
.parse::<usize>()
.unwrap();
let mut rng = rand::thread_rng();

eprintln!("Parse input...");
let mut boards_with_results = Vec::new();
let input_file = File::open(Path::new(input_path)).unwrap();
for line in BufReader::new(input_file).lines() {
let record = line.unwrap().parse::<Record>().unwrap();
for record in load_records(Path::new(input_path)).unwrap() {
let record = record.unwrap();
let mut timeline = record.timeline().unwrap();
boards_with_results.append(&mut timeline);
}

eprintln!("Total board count = {}", boards_with_results.len());

boards_with_results.shuffle(&mut rng);

eprintln!("Writing to file...");
let out_f = File::create(output_path).unwrap();
let mut writer = BufWriter::new(out_f);
Expand All @@ -73,13 +75,23 @@ pub fn gen_dataset(matches: &ArgMatches) {
if idx >= max_output {
break;
}
if let Hand::Play(pos) = hand {
writeln!(
&mut writer,
"{:016x} {:016x} {} {}",
board.player, board.opponent, score, pos,
)
.unwrap();
match hand {
Hand::Play(pos) => {
writeln!(
&mut writer,
"{:016x} {:016x} {} {}",
board.player, board.opponent, score, pos,
)
.unwrap();
}
Hand::Pass => {
writeln!(
&mut writer,
"{:016x} {:016x} {} ps",
board.player, board.opponent, score,
)
.unwrap();
}
}
}
eprintln!("Finished!");
Expand Down

0 comments on commit dd3ae08

Please sign in to comment.