Skip to content

Commit

Permalink
Add the pooling operators to the pyo3 layer.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Oct 13, 2023
1 parent 75989fc commit 4f76c41
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 0 deletions.
2 changes: 2 additions & 0 deletions candle-pyo3/py_src/candle/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Generated content DO NOT EDIT
from .. import functional

# Re-export the native ops from the compiled `functional` extension module so
# they are importable as `candle.functional.<op>`. Per the header above, this
# file is auto-generated — changes belong in the stub generator, not here.
avg_pool2d = functional.avg_pool2d
gelu = functional.gelu
max_pool2d = functional.max_pool2d
relu = functional.relu
silu = functional.silu
softmax = functional.softmax
Expand Down
14 changes: 14 additions & 0 deletions candle-pyo3/py_src/candle/functional/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,27 @@ from os import PathLike
from candle.typing import _ArrayLike, Device
from candle import Tensor, DType, QTensor

@staticmethod
def avg_pool2d(tensor: Tensor, ksize: int, stride: int = 1) -> Tensor:
    """
    Applies the 2d avg-pool function to a given tensor.
    """
    pass

@staticmethod
def gelu(tensor: Tensor) -> Tensor:
    """
    Applies the Gaussian Error Linear Unit (GELU) function to a given tensor.
    """
    ...

@staticmethod
def max_pool2d(tensor: Tensor, ksize: int, stride: int = 1) -> Tensor:
    """
    Applies the 2d max-pool function to a given tensor.
    """
    pass

@staticmethod
def relu(tensor: Tensor) -> Tensor:
"""
Expand Down
24 changes: 24 additions & 0 deletions candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,28 @@ fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
Ok(PyTensor(sm))
}

#[pyfunction]
#[pyo3(signature = (tensor, ksize, *, stride=1), text_signature = "(tensor:Tensor, ksize:int, stride:int=1)")]
/// Applies the 2d avg-pool function to a given tensor.
/// &RETURNS&: Tensor
fn avg_pool2d(tensor: PyTensor, ksize: usize, stride: usize) -> PyResult<PyTensor> {
    // NOTE: this doc comment is the source for the generated Python stubs, so
    // the stray trailing `#` removed above would otherwise leak into the .pyi.
    // `stride` is keyword-only on the Python side (see the `*` in the signature).
    let tensor = tensor
        .avg_pool2d_with_stride(ksize, stride)
        .map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

#[pyfunction]
#[pyo3(signature = (tensor, ksize, *, stride=1), text_signature = "(tensor:Tensor, ksize:int, stride:int=1)")]
/// Applies the 2d max-pool function to a given tensor.
/// &RETURNS&: Tensor
fn max_pool2d(tensor: PyTensor, ksize: usize, stride: usize) -> PyResult<PyTensor> {
    // Mirrors `avg_pool2d` below/above; `stride` is keyword-only on the Python
    // side (the `*` in the pyo3 signature) and defaults to 1. Errors from the
    // underlying candle op are converted to Python exceptions via `wrap_err`.
    let tensor = tensor
        .max_pool2d_with_stride(ksize, stride)
        .map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
Expand Down Expand Up @@ -1263,6 +1285,8 @@ fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(silu, m)?)?;
m.add_function(wrap_pyfunction!(softmax, m)?)?;
m.add_function(wrap_pyfunction!(max_pool2d, m)?)?;
m.add_function(wrap_pyfunction!(avg_pool2d, m)?)?;
m.add_function(wrap_pyfunction!(gelu, m)?)?;
m.add_function(wrap_pyfunction!(relu, m)?)?;
m.add_function(wrap_pyfunction!(tanh, m)?)?;
Expand Down

0 comments on commit 4f76c41

Please sign in to comment.