Avoid trying to backprop through non-differentiable layers.
LaurentMazare committed Oct 14, 2023
1 parent 8921d50 commit 3af3702
Showing 1 changed file with 12 additions and 2 deletions.
candle-core/src/backprop.rs
@@ -36,6 +36,8 @@ impl Tensor {
                 // Do not call recursively on the "leaf" nodes.
                 track_grad = true;
                 nodes
+            } else if node.dtype().is_int() {
+                nodes
             } else if let Some(op) = node.op() {
                 match op {
                     Op::IndexAdd(t1, t2, t3, _)
@@ -103,7 +105,6 @@ impl Tensor {
                     | Op::Broadcast(node)
                     | Op::Cmp(node, _)
                     | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
-                    | Op::ToDType(node)
                     | Op::ToDevice(node)
                     | Op::Transpose(node, _, _)
                     | Op::Permute(node, _)
@@ -116,6 +117,15 @@
                         track_grad |= tg;
                         nodes
                     }
+                    Op::ToDType(node) => {
+                        if node.dtype().is_float() {
+                            let (tg, nodes) = walk(node, nodes, already_seen);
+                            track_grad |= tg;
+                            nodes
+                        } else {
+                            nodes
+                        }
+                    }
                     Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
                 }
             } else {
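The three hunks above change how the graph is walked when collecting nodes for backpropagation: a node with an integer dtype is now treated as a constant and never recursed into, and a ToDType cast only keeps tracking gradients when its argument is a float tensor. The sketch below illustrates the resulting behaviour; it assumes candle's Var/Tensor API, and the `w`/`ids` names and values are purely illustrative, not part of this commit.

use candle_core::{DType, Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A float parameter that should receive a gradient.
    let w = Var::new(&[1f32, 2., 3.], &dev)?;
    // Integer data (e.g. token ids); its dtype is_int(), so the walk skips it.
    let ids = Tensor::new(&[0u32, 1, 2], &dev)?;
    // Casting the ids to f32 creates a ToDType node with an integer argument;
    // after this commit the walk no longer recurses into that branch.
    let ids_f = ids.to_dtype(DType::F32)?;
    let loss = w.mul(&ids_f)?.sum_all()?;
    let grads = loss.backward()?;
    // Only the float variable ends up with a gradient; the integer branch is ignored.
    assert!(grads.get(&w).is_some());
    assert!(grads.get(&ids).is_none());
    Ok(())
}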
@@ -374,7 +384,7 @@ impl Tensor {
                 }
                 Op::ToDType(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
-                    *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
+                    *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
                 }
                 Op::Copy(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
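The last hunk fixes the backward step for ToDType itself: the incoming gradient is now cast back to the dtype of the cast's argument (arg.dtype()) rather than to the dtype of the cast's output, so it can be accumulated into the argument's gradient without a dtype mismatch. A short sketch of the effect, under the same assumptions as above:

use candle_core::{DType, Device, Result, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Var::new(&[1f32, 2., 3.], &dev)?;
    // Upcast to f64 before computing a scalar loss.
    let loss = x.to_dtype(DType::F64)?.sqr()?.sum_all()?;
    let grads = loss.backward()?;
    // The gradient flowing through the ToDType node is converted back to the
    // argument's dtype, so the grad stored for the f32 variable is itself f32.
    let gx = grads.get(&x).expect("no gradient for x");
    assert_eq!(gx.dtype(), DType::F32);
    Ok(())
}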