Skip to content

Commit

Permalink
Move Shape to its own file.
Browse files Browse the repository at this point in the history
Expand no longer adds dimensions.
  • Loading branch information
kurtschelfthout committed May 1, 2023
1 parent baa7262 commit 5c0ac45
Show file tree
Hide file tree
Showing 12 changed files with 314 additions and 101 deletions.
45 changes: 45 additions & 0 deletions examples/eye.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use tensorken::tensor::Cpu32;

/// Prints an expression followed by its evaluated result.
/// The one-argument arm uses `Display` formatting; passing any extra
/// literal (its value is ignored — it only selects the arm) switches
/// the result to `Debug` formatting.
macro_rules! do_example {
    ($ex:expr) => {
        println!(">>> {}", stringify!($ex));
        let value = $ex;
        println!("{}", value);
    };
    ($ex:expr, $marker:literal) => {
        println!(">>> {}", stringify!($ex));
        let value = $ex;
        println!("{:?}", value);
    };
}

/// Prints a `let` binding and its result, then introduces that binding
/// into the caller's scope. As with `do_example!`, an extra literal
/// argument (value ignored) selects `Debug` instead of `Display` output.
macro_rules! let_example {
    ($name:ident, $ex:expr) => {
        println!(">>> {}", stringify!(let $name = $ex));
        let $name = $ex;
        println!("{}", $name);
    };
    ($name:ident, $ex:expr, $marker:literal) => {
        println!(">>> {}", stringify!(let $name = $ex));
        let $name = $ex;
        println!("{:?}", $name);
    };
}

// Shorthand for the f32-on-CPU tensor front-end used throughout this example.
type Tr = Cpu32;

fn main() {
    // how to make an eye
    // Shows the built-in `eye`, then reconstructs a dim x dim identity matrix
    // from a single scalar using only movement ops (pad/reshape/expand/crop).
    let_example!(dim, 3);
    do_example!(&Tr::eye(dim));
    // Start from the scalar 1.0.
    let_example!(t, &Tr::scalar(1.0));
    // Pad with `dim` entries on the right, giving dim+1 elements:
    // [1, 0, ..., 0] (assumes `pad` fills with zeros — TODO confirm).
    let_example!(t, t.pad(&[(0, dim)]));
    // View as a single row of width dim+1.
    let_example!(t, t.reshape(&[1, dim + 1]));
    // Repeat that row `dim` times (note: since this commit, `expand` no
    // longer adds dimensions, hence the explicit reshape to 2-D above).
    let_example!(t, t.expand(&[dim, dim + 1]));
    // Flatten; each row being one element too wide staggers the 1s diagonally.
    let_example!(t, t.reshape(&[dim * (dim + 1)]));
    // Keep only the first dim*dim elements.
    let_example!(t, t.crop(&[(0, dim * dim)]));
    // Fold back into a dim x dim matrix: the identity.
    let_example!(t, t.reshape(&[dim, dim]));
}
69 changes: 69 additions & 0 deletions examples/matmul.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use tensorken::{shape::Shape, tensor::Cpu32};

/// Prints an expression followed by its evaluated result.
/// The one-argument arm uses `Display` formatting; passing any extra
/// literal (its value is ignored — it only selects the arm) switches
/// the result to `Debug` formatting.
macro_rules! do_example {
    ($ex:expr) => {
        println!(">>> {}", stringify!($ex));
        let value = $ex;
        println!("{}", value);
    };
    ($ex:expr, $marker:literal) => {
        println!(">>> {}", stringify!($ex));
        let value = $ex;
        println!("{:?}", value);
    };
}

/// Prints a `let` binding and its result, then introduces that binding
/// into the caller's scope. As with `do_example!`, an extra literal
/// argument (value ignored) selects `Debug` instead of `Display` output.
macro_rules! let_example {
    ($name:ident, $ex:expr) => {
        println!(">>> {}", stringify!(let $name = $ex));
        let $name = $ex;
        println!("{}", $name);
    };
    ($name:ident, $ex:expr, $marker:literal) => {
        println!(">>> {}", stringify!(let $name = $ex));
        let $name = $ex;
        println!("{:?}", $name);
    };
}

// Shorthand for the f32-on-CPU tensor front-end used throughout this example.
type Tr = Cpu32;

fn main() {
    // how to multiply matrices, the hard way
    // First the built-in matmul on a [3,4] x [4,3] pair, then the same
    // product decomposed into reshape/transpose, elementwise multiply,
    // sum, and a final reshape.
    let_example!(l, Tr::linspace(0.0, 11.0, 12).reshape(&[3, 4]));
    let_example!(r, Tr::linspace(12.0, 23.0, 12).reshape(&[4, 3]));
    do_example!(&l.matmul(&r));

    // left's shape from [..., m, n] to [..., m, 1, n]
    // (the trailing `true` selects Debug printing for the shape slices)
    let_example!(s, l.shape(), true);
    let_example!(
        l_shape,
        // Keep all leading dims, insert a 1 before the last dim.
        [&s[..s.ndims() - 1], &[1, s[s.ndims() - 1]]].concat(),
        true
    );
    let_example!(l, l.reshape(&l_shape));

    // right's shape from [..., n, o] to [..., 1, o, n]
    let_example!(s, r.shape(), true);
    let_example!(
        r_shape,
        // Insert a 1 before the last two dims...
        [&s[..s.ndims() - 2], &[1], &s[s.ndims() - 2..]].concat(),
        true
    );
    let_example!(
        r,
        // ...then swap the last two axes: [..., 1, n, o] -> [..., 1, o, n].
        r.reshape(&r_shape)
            .transpose(r_shape.ndims() - 1, r_shape.ndims() - 2)
    );

    // after multiply: [..., m, o, n]
    let_example!(prod, &l * &r);
    // after sum: [..., m, o, 1]
    let_example!(sum, prod.sum(&[prod.shape().ndims() - 1]));
    // after reshape: [..., m, o]
    let_example!(s, sum.shape(), true);
    do_example!(sum.reshape(&s[..s.ndims() - 1]));
}
112 changes: 112 additions & 0 deletions examples/tour.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use tensorken::raw_tensor_cpu::CpuRawTensor;
use tensorken::tensor::{Cpu32, IndexValue, Tensor};

/// Prints an expression followed by its evaluated result.
/// The one-argument arm uses `Display` formatting; passing any extra
/// literal (its value is ignored — it only selects the arm) switches
/// the result to `Debug` formatting.
macro_rules! do_example {
    ($ex:expr) => {
        println!(">>> {}", stringify!($ex));
        let value = $ex;
        println!("{}", value);
    };
    ($ex:expr, $marker:literal) => {
        println!(">>> {}", stringify!($ex));
        let value = $ex;
        println!("{:?}", value);
    };
}

/// Prints a `let` binding and its result, then introduces that binding
/// into the caller's scope. As with `do_example!`, an extra literal
/// argument (value ignored) selects `Debug` instead of `Display` output.
macro_rules! let_example {
    ($name:ident, $ex:expr) => {
        println!(">>> {}", stringify!(let $name = $ex));
        let $name = $ex;
        println!("{}", $name);
    };
    ($name:ident, $ex:expr, $marker:literal) => {
        println!(">>> {}", stringify!(let $name = $ex));
        let $name = $ex;
        println!("{:?}", $name);
    };
}

// Shorthand for the f32-on-CPU tensor front-end used throughout this tour.
type Tr = Cpu32;

fn main() {
    // Constructing a tensor with the fully spelled-out generic type;
    // `Tr` below is the same thing with less ceremony.
    do_example!(Tensor::<CpuRawTensor<f32>>::new(
        &[3, 2],
        &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
    ));

    // unary operations
    let_example!(t, &Tr::new(&[3, 2], &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]));

    do_example!(t.exp());
    do_example!(t.log());

    // binary operations
    let_example!(t1, &Tr::new(&[2, 2], &[0.0, 1.0, 2.0, 3.0]));
    let_example!(t2, &Tr::new(&[2, 2], &[6.0, 7.0, 8.0, 9.0]));

    do_example!(t1 + t2);
    do_example!(t1 * t2);
    do_example!(t1.matmul(t2));

    // broadcasting
    // A scalar combines with a tensor of any shape.
    let_example!(t1, &Tr::new(&[6], &[2.0, 1.0, 4.0, 2.0, 8.0, 4.0]));
    let_example!(s1, &Tr::scalar(2.0));
    do_example!((t1.shape(), s1.shape()), true);
    do_example!(t1 + s1);

    // Size-1 dimensions broadcast against the other operand's dimensions.
    let_example!(t1, &Tr::new(&[3, 2], &[2.0, 1.0, 4.0, 2.0, 8.0, 4.0]));
    let_example!(t2, &Tr::new(&[1, 2], &[10.0, 100.0]));
    do_example!(t1 + t2);
    let_example!(t3, &Tr::new(&[3, 1], &[10.0, 100.0, 1000.]));
    do_example!(t1 + t3);

    let_example!(t1, &Tr::new(&[2, 3], &[2.0, 1.0, 4.0, 2.0, 8.0, 4.0]));
    let_example!(s1, &Tr::scalar(2.0));
    do_example!((t1.shape(), s1.shape()), true);
    do_example!(t1 + s1);

    // Operands with fewer dimensions broadcast as well ([3,2] + [2]).
    let_example!(t1, &Tr::new(&[3, 2], &[2.0, 1.0, 4.0, 2.0, 8.0, 4.0]));
    let_example!(t2, &Tr::new(&[2], &[10.0, 100.0]));
    do_example!((t1.shape(), t2.shape()), true);
    do_example!(t1 + t2);

    // reduce operations
    // `sum` takes the axes to reduce over.
    let_example!(t, &Tr::new(&[4], &[0.0, 1.0, 2.0, 3.0]));
    do_example!(t.sum(&[0]));

    let_example!(t, &Tr::new(&[2, 2], &[0.0, 1.0, 2.0, 3.0]));
    do_example!(t.sum(&[0, 1]));
    do_example!(t.sum(&[0]));
    do_example!(t.sum(&[1]));

    // movement ops/slicing and dicing
    // `expand` repeats size-1 dimensions without copying.
    let_example!(t, &Tr::new(&[1, 2, 2], &[0.0, 1.0, 2.0, 3.0]));
    do_example!(t.expand(&[5, 2, 2]));

    // `crop` keeps a (start, end) range per dimension.
    let_example!(t, &Tr::new(&[3, 2], &[2.0, 1.0, 4.0, 2.0, 8.0, 4.0]));
    do_example!(t.crop(&[(0, 2), (1, 2)]));

    // `pad` adds (before, after) entries per dimension.
    let_example!(t, &Tr::new(&[3, 2], &[2.0, 1.0, 4.0, 2.0, 8.0, 4.0]));
    do_example!(t.pad(&[(1, 2), (1, 3)]));

    // `at` indexes a single row or a single element.
    let_example!(t, &Tr::new(&[2, 2], &[0.0, 1.0, 2.0, 3.0]));
    do_example!(t.at(1));
    do_example!(t.at(&[1, 0]));

    // `reshape` reinterprets the same 24 elements under different shapes.
    let_example!(t, Tr::linspace(0.0, 23.0, 24));
    let_example!(t6x4, t.reshape(&[6, 4]));
    let_example!(t3x8, t6x4.reshape(&[3, 8]));

    // `permute` reorders axes; [1, 0] is a 2-D transpose.
    do_example!(t3x8.permute(&[1, 0]));

    // broadcasting with matmul in more than 2 dimensions boggles the mind
    let_example!(t1, &Tr::linspace(1.0, 36.0, 36).reshape(&[3, 2, 2, 3]));
    let_example!(t2, &Tr::linspace(37.0, 72.0, 36).reshape(&[3, 2, 3, 2]));
    do_example!(t1.matmul(t2));
    let_example!(t3, &Tr::linspace(39.0, 72.0, 12).reshape(&[2, 3, 2]));
    do_example!(t1.matmul(t3));
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod num;
pub mod raw_tensor;
pub mod raw_tensor_cpu;
pub mod raw_tensor_wgpu;
pub mod shape;
mod shape_strider;
pub mod tensor;
pub mod tensor_mut;
5 changes: 1 addition & 4 deletions src/raw_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@ use crate::{num::Num, raw_tensor_cpu::CpuRawTensor};
/// As such it can be used to implement a new type of accelerator, but can also support
/// optimizations like fusing.
/// Think of `RawTensor` as the DSL for accelerators, in final style.
pub trait RawTensor
where
Self: Sized,
{
pub trait RawTensor {
// Note: Elem is an associated type, not a generic parameter, for rather subtle reasons.
// We often want to implement traits for e.g. Tensor<impl RawTensor> without having to mention
// the element type, as the element type is not restricted by the implementation. See e.g. Add, Neg on Tensor:
Expand Down
23 changes: 12 additions & 11 deletions src/raw_tensor_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use std::rc::Rc;

use crate::num::Num;
use crate::raw_tensor::RawTensor;
use crate::shape_strider::{Shape, ShapeStrider, TensorIndexIterator};
use crate::shape::Shape;
use crate::shape_strider::{ShapeStrider, TensorIndexIterator};

/// Implementation of `RawTensor` for CPU.
/// The "numpy" part of the tensor library.
Expand Down Expand Up @@ -352,25 +353,25 @@ mod tests {
#[test]
fn test_expand_scalar() {
let t = CpuRawTensor::new_into(&[1], vec![42.0]);
let t = t.expand(&[5, 4]);
let t = t.expand(&[4]);

assert_eq!(t.shape(), &[5, 4]);
assert_eq!(t.strides(), &[0, 0]);
assert_eq!(t.ravel(), repeat(42.0).take(20).collect::<Vec<_>>());
assert_eq!(t.shape(), &[4]);
assert_eq!(t.strides(), &[0]);
assert_eq!(t.ravel(), repeat(42.0).take(4).collect::<Vec<_>>());
}

#[test]
fn test_expand_3x1() {
let t = CpuRawTensor::new_into(&[3, 1], make_vec(3));
let t = t.expand(&[15, 3, 5]);
let t = t.expand(&[3, 5]);

assert_eq!(t.shape(), &[15, 3, 5]);
assert_eq!(t.strides(), &[0, 1, 0]);
assert_eq!(t.shape(), &[3, 5]);
assert_eq!(t.strides(), &[1, 0]);
}

#[test]
fn test_expand_2x3x4() {
let t = CpuRawTensor::new_into(&[2, 3, 4], make_vec(24));
fn test_expand_1x2x3x4() {
let t = CpuRawTensor::new_into(&[1, 2, 3, 4], make_vec(24));
let t = t.expand(&[5, 2, 3, 4]);

assert_eq!(t.shape(), &[5, 2, 3, 4]);
Expand Down Expand Up @@ -427,7 +428,7 @@ mod tests {

#[test]
fn test_binary_ops_different_strides() {
let t1 = CpuRawTensor::new_into(&[1], vec![20.0]).expand(&[2, 3]);
let t1 = CpuRawTensor::new_into(&[1, 1], vec![20.0]).expand(&[2, 3]);
let t2 = CpuRawTensor::new_into(&[2, 3], make_vec(6));
let t = t1.add(&t2);
assert_eq!(t.shape(), &[2, 3]);
Expand Down
8 changes: 3 additions & 5 deletions src/raw_tensor_wgpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ use bytemuck::{NoUninit, Pod};
use wgpu::util::DeviceExt;

use crate::{
num::Num,
raw_tensor::RawTensor,
raw_tensor_cpu::CpuRawTensor,
shape_strider::{Shape, ShapeStrider},
num::Num, raw_tensor::RawTensor, raw_tensor_cpu::CpuRawTensor, shape::Shape,
shape_strider::ShapeStrider,
};

// Misc WGSL notes/tips:
Expand Down Expand Up @@ -828,7 +826,7 @@ mod tests {

#[test]
fn test_expand() {
let t = WgpuRawTensor::new(&[1], &[42.0], get_wgpu_device());
let t = WgpuRawTensor::new(&[1, 1], &[42.0], get_wgpu_device());
let t = t.expand(&[5, 4]);

assert_eq!(t.shape(), &[5, 4]);
Expand Down
32 changes: 32 additions & 0 deletions src/shape.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/// A trait for types that can be used as shapes for tensors,
/// with some convenience methods for working with shapes.
pub trait Shape {
    /// Returns the shape as a slice of dimension sizes.
    fn shape(&self) -> &[usize];

    /// Returns the number of dimensions.
    fn ndims(&self) -> usize {
        self.shape().len()
    }

    /// Returns the total number of elements.
    ///
    /// NOTE(review): a zero-dimensional shape is defined to have size 0
    /// here, which differs from the empty-product convention (where a
    /// scalar would have size 1) — presumably intentional for this crate.
    fn size(&self) -> usize {
        match self.shape() {
            [] => 0,
            dims => dims.iter().product(),
        }
    }
}

/// Lets a plain `&[usize]` slice be used directly as a shape.
impl Shape for &[usize] {
    fn shape(&self) -> &[usize] {
        self
    }
}

/// Lets an owned `Vec<usize>` be used directly as a shape.
impl Shape for Vec<usize> {
    fn shape(&self) -> &[usize] {
        self
    }
}
Loading

0 comments on commit 5c0ac45

Please sign in to comment.