forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathderivatives.yaml
1488 lines (1087 loc) · 69.8 KB
/
derivatives.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Defines derivative formulas and Python signatures of methods on Variable
#
# Each entry consists of:
# - A 'name', which specifies the ATen name of the function you
# are defining derivatives for, and an argument specification.
# - One or more gradient entries, each mapping the name of a differentiable
# input to a formula specifying how to compute its gradient.
# Note that a single gradient entry can specify the gradient
# formula for multiple input names, by specifying a key
# "input1, input2" (see atan2 for an example).
# - An argument can be flagged as 'non_differentiable'.
# In general there are 3 possibilities:
# 1. An argument has an entry with a specified gradient
# 2. An argument has an entry specified as not differentiable
# 3. An argument has no entry
# Using the flag 'non_differentiable' resolves to the second case.
# The second case was introduced to support arguments that are not
# differentiable, e.g. arguments of type IndexTensor for 'embedding'.
# TODO: Determine whether case 3 and case 2 can be replaced by one concept.
# - An optional entry with key 'output_differentiability', whose value is a
# list with one element per output of the forward function. The list
# should contain only booleans, specifying whether each output Tensor
# is differentiable.
# If none of the outputs are differentiable, you can also add the function
# name to the `DONT_REQUIRE_DERIVATIVE` list in `gen_variable_type.py`.
#
# If a function has out-of-place and in-place variants, then the derivative
# definition for the in-place variant is optional. It will default to the
# definition for the out-of-place variant. Similarly, _out variants will
# default to the derivative for the non _out variant.
#
# Gradient expressions are standard C++ expressions operating on ATen
# variables. In a gradient expression, the following variables are in
# scope:
#
# - 'grad', the gradient of the output (often spelled grad_output
# in Python) which we are going to left-multiply.
#
# When a function returns multiple *differentiable* outputs,
# you can refer to the gradient of each output using 'grads',
# e.g., 'grads[0]', 'grads[1]'.
#
# When a function returns *one* differentiable output (the
# first output) and some more nondifferentiable outputs,
# you MUST refer to the gradient of the differentiable output with
# 'grad' (this case is special-cased in our code generation).
#
# Note that the number of differentiable outputs can be modified by the
# 'output_differentiability' entry (see above).
#
# - Any of the input arguments, tensor or non-tensor, including
# argument names that only appear in Declarations.cwrap, e.g. 'output'.
#
# - 'result', representing the result of evaluating the forward
# expression for ATen native function declarations. If the forward
# expression outputs a tuple, use 'resultX' instead to access the
# X-th entry (0-indexed, e.g. 'result0', 'result1').
#
# - 'grad_input_mask', a std::array<bool, n>, specifies which input
# gradients are actually needed. For example, in the entry
# `input0, input1: foo(grad_input_mask)`, `grad_input_mask` is a size
# two array, where `grad_input_mask[0]` is true if `input0` requires
# grad, and `grad_input_mask[1]` is true if `input1` requires grad.
#
# (NB: if your function computes the gradient for a list of tensors,
# the `grad_input_mask` will only have a single entry for the list,
# specifying whether at least one tensor from the list requires
# grad. If we want to support more fine-grained signalling,
# we'll need some alternate variable which is not a std::array.)
#
# - 'retain_variables', a bool which is true if a user has specified
# that saved variables should be retained in case the backwards is
# run again later. This allows an optimization where we can
# destroy saved buffers if we know variables are not going to be retained,
# e.g., it is used by _cudnn_rnn
#
# If you need a complex expression, e.g., with local variables,
# write a _backward function in tools/autograd/templates/Functions.cpp
# and invoke it from here. By the way, go read
# https://github.com/zdevito/ATen/issues/163; this describes an
# important hazard that occurs when porting backwards from Python to C++
#
# Double backwards gradient expressions can be somewhat confusing;
# the most important thing to remember is: (1) you need to define a
# derivative formula for every input, including inputs named things
# like 'grad_output', and (2) the gradient to multiply with is always
# called 'grad' (even though it really is a grad-grad).
#
# NB: There are a number of gradient definitions in here which are bogus
# (implemented using zeros_like). These gradients are (hopefully) not
# used by our frontend. You MUST check the frontend code; search for
# OpName.apply to see if it's still using a legacy Python style API.
#
# NB: The parameter names here MUST be consistent with the parameter names
# in ./torch/lib/ATen/Declarations.cwrap
- name: abs(Tensor self)
self: grad * self.sign()
- name: acos(Tensor self)
self: grad * -((-self * self + 1).rsqrt())
- name: add(Tensor self, Tensor other, *, Scalar alpha)
self: grad
other: maybe_multiply(grad, alpha)
- name: add(Tensor self, Scalar other, *, Scalar alpha)
self: grad
- name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha)
self: maybe_multiply(grad, beta)
batch1: grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) }).bmm(batch2.transpose(1, 2)) * alpha
batch2: batch1.transpose(1, 2).bmm(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) })) * alpha
- name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value)
self: grad
tensor1: grad * value / tensor2
tensor2: -grad * value * tensor1 / (tensor2 * tensor2)
- name: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value)
self: grad
tensor1: grad * tensor2 * value
tensor2: grad * tensor1 * value
- name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha)
self: maybe_multiply(grad, beta)
mat1: mm_mat1_backward(grad, mat2, mat1, alpha)
mat2: mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha)
- name: _sparse_addmm(Tensor self, Tensor sparse, Tensor dense, *, Scalar beta, Scalar alpha)
self: maybe_multiply(grad, beta)
sparse: _sparse_addmm_sparse_backward(grad, sparse, dense, alpha)
dense: mm_mat2_backward(grad, sparse, dense.sizes(), dense.strides(), alpha)
- name: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta, Scalar alpha)
self: maybe_multiply(grad, beta)
mat: grad.ger(vec) * alpha
vec: mat.t().mv(grad) * alpha
- name: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta, Scalar alpha)
self: maybe_multiply(grad, beta)
vec1: grad.mv(vec2) * alpha
vec2: grad.t().mv(vec1) * alpha
- name: affine_grid_generator(Tensor theta, IntArrayRef size)
theta: affine_grid_generator_backward(grad, size)
- name: alias(Tensor self)
self: grad
# The four items below are necessary because TensorIterator doesn't work on
# Variables (codegen does not unwrap the input Tensor for all() and any()).
- name: any(Tensor self)
self: not_implemented("any")
- name: any(Tensor self, int64_t dim, bool keepdim)
self: not_implemented("any")
- name: all(Tensor self)
self: not_implemented("all")
- name: all(Tensor self, int64_t dim, bool keepdim)
self: not_implemented("all")
- name: as_strided(Tensor self, IntArrayRef size, IntArrayRef stride, int64_t? storage_offset)
self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)
- name: asin(Tensor self)
self: grad * (-self * self + 1).rsqrt()
- name: atan(Tensor self)
self: grad / (self * self + 1)
- name: atan2(Tensor self, Tensor other)
self, other: atan2_backward(grad, self, other, grad_input_mask)
- name: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha)
self: maybe_multiply(grad, beta)
batch1: grad.bmm(batch2.transpose(1, 2)) * alpha
batch2: batch1.transpose(1, 2).bmm(grad) * alpha
- name: bernoulli(Tensor self, Generator generator)
self: zeros_like(grad)
- name: bernoulli_(Tensor self, Tensor p, Generator generator)
self: zeros_like(grad)
p: zeros_like(p)
- name: bernoulli_(Tensor self, double p, Generator generator)
self: zeros_like(grad)
- name: bmm(Tensor self, Tensor mat2)
self: grad.bmm(mat2.transpose(1, 2))
mat2: self.transpose(1, 2).bmm(grad)
- name: cat(TensorList tensors, int64_t dim)
tensors: cat_tensors_backward(grad, to_args_sizes(tensors), dim)
- name: cauchy_(Tensor self, double median, double sigma, Generator generator)
self: zeros_like(grad)
- name: ceil(Tensor self)
self: zeros_like(grad)
- name: cholesky(Tensor self, bool upper)
self: cholesky_backward(grad, upper, result)
- name: cholesky_solve(Tensor self, Tensor input2, bool upper)
self: not_implemented("cholesky_solve")
input2: not_implemented("cholesky_solve")
- name: cholesky_inverse(Tensor self, bool upper)
self: not_implemented("cholesky_inverse")
# For clamp, gradient is not defined at the boundaries. But empirically it's helpful
# to be able to get gradient on min and max, so we return the subgradient 1 for these cases.
- name: clamp(Tensor self, Scalar? min, Scalar? max)
self: clamp_backward(grad, self, min, max)
- name: clamp_min(Tensor self, Scalar min)
self: grad * (self >= min).to(grad.dtype())
- name: clamp_max(Tensor self, Scalar max)
self: grad * (self <= max).to(grad.dtype())
- name: clone(Tensor self)
self: grad
- name: coalesce(Tensor self)
self: grad
- name: cos(Tensor self)
self: grad * -self.sin()
- name: cosh(Tensor self)
self: grad * self.sinh()
- name: cross(Tensor self, Tensor other, int64_t? dim)
self: other.cross(grad, dim)
other: grad.cross(self, dim)
- name: cumprod(Tensor self, int64_t dim)
self: cumprod_backward(grad, self, dim)
- name: cumprod(Tensor self, int64_t dim, *, ScalarType dtype)
self: cumprod_backward(grad, self, dim, dtype)
- name: cumsum(Tensor self, int64_t dim)
self: cumsum_backward(grad, dim)
- name: cumsum(Tensor self, int64_t dim, *, ScalarType dtype)
self: cumsum_backward(grad, dim, self.scalar_type())
- name: conv_tbc(Tensor self, Tensor weight, Tensor bias, int64_t pad)
self, weight, bias: conv_tbc_backward(grad, self, weight, bias, pad)
- name: _ctc_loss(Tensor log_probs, Tensor targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, bool zero_infinity)
log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank, zero_infinity)
- name: det(Tensor self)
self: det_backward(grad, self, result)
- name: diag(Tensor self, int64_t diagonal)
self: diag_backward(grad, self.sizes(), diagonal)
- name: diagonal(Tensor self, int64_t offset, int64_t dim1, int64_t dim2)
self: diagonal_backward(grad, self.sizes(), offset, dim1, dim2)
- name: dist(Tensor self, Tensor other, Scalar p)
self: norm_backward(grad, self - other, p, result)
other: -norm_backward(grad, self - other, p, result)
- name: div(Tensor self, Tensor other)
self: grad / other
other: -grad * self / (other * other)
- name: div(Tensor self, Scalar other)
self: grad / other
- name: dot(Tensor self, Tensor tensor)
self: grad * tensor
tensor: grad * self
- name: _fused_dropout(Tensor self, double p, Generator generator)
self: _fused_dropout_backward(grad, result1, p)
- name: eig(Tensor self, bool eigenvectors)
self: not_implemented("eig")
- name: eq_(Tensor self, Scalar other)
self: zeros_like(self)
- name: eq_(Tensor self, Tensor other)
self: zeros_like(self)
other: zeros_like(other)
- name: erf(Tensor self)
self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad
- name: erfc(Tensor self)
self: -2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad
- name: erfinv(Tensor self)
self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad
- name: exp(Tensor self)
self: grad * result
- name: expm1(Tensor self)
self: grad * (result + 1)
- name: expand(Tensor self, IntArrayRef size, *, bool implicit)
self: at::sum_to(grad, self.sizes())
- name: exponential_(Tensor self, double lambd, Generator generator)
self: zeros_like(grad)
- name: fill_(Tensor self, Scalar value)
self: zeros_like(grad)
- name: fill_(Tensor self, Tensor value)
self: zeros_like(grad)
value: grad.sum()
- name: floor(Tensor self)
self: zeros_like(grad)
- name: fmod(Tensor self, Scalar other)
self: grad
- name: fmod(Tensor self, Tensor other)
self: grad
other: 'not_implemented("fmod: other")'
- name: frac(Tensor self)
self: grad
- name: gather(Tensor self, int64_t dim, Tensor index, bool sparse_grad)
self: "sparse_grad ? at::_gather_sparse_backward(self, dim, index, grad) : at::zeros(self.sizes(), grad.options()).scatter_add_(dim, index, grad)"
- name: ge_(Tensor self, Scalar other)
self: zeros_like(self)
- name: ge_(Tensor self, Tensor other)
self: zeros_like(self)
other: zeros_like(other)
- name: gels(Tensor self, Tensor A)
self: not_implemented("gels")
A: not_implemented("gels")
- name: geometric_(Tensor self, double p, Generator generator)
self: zeros_like(grad)
- name: geqrf(Tensor self)
self: not_implemented("geqrf")
- name: ger(Tensor self, Tensor vec2)
self: grad.mv(vec2)
vec2: grad.t().mv(self)
- name: indices(Tensor self)
# The indices of a sparse tensor are integer metadata; the output is marked
# non-differentiable (lowercase booleans, matching the other entries).
output_differentiability: [false]
- name: _indices(Tensor self)
output_differentiability: [false]
- name: grid_sampler_2d(Tensor input, Tensor grid, int64_t interpolation_mode, int64_t padding_mode)
input, grid: grid_sampler_2d_backward(grad, input, grid, interpolation_mode, padding_mode)
- name: grid_sampler_3d(Tensor input, Tensor grid, int64_t interpolation_mode, int64_t padding_mode)
input, grid: grid_sampler_3d_backward(grad, input, grid, interpolation_mode, padding_mode)
- name: gt_(Tensor self, Scalar other)
self: zeros_like(self)
- name: gt_(Tensor self, Tensor other)
self: zeros_like(self)
other: zeros_like(other)
- name: histc(Tensor self, int64_t bins, Scalar min, Scalar max)
self: not_implemented("histc")
- name: index(Tensor self, TensorList indices)
self: zeros_like(self).index_put_(indices, grad, true)
indices: TensorList()
- name: index_add_(Tensor self, int64_t dim, Tensor index, Tensor source)
self: grad
source: grad.index_select(dim, index)
- name: index_copy_(Tensor self, int64_t dim, Tensor index, Tensor source)
self: grad.clone().index_fill_(dim, index, 0)
source: grad.index_select(dim, index)
- name: index_fill_(Tensor self, int64_t dim, Tensor index, Scalar value)
self: grad.clone().index_fill_(dim, index, 0)
- name: index_fill_(Tensor self, int64_t dim, Tensor index, Tensor value)
self: grad.clone().index_fill_(dim, index, 0)
value: grad.index_select(dim, index).sum()
- name: index_put_(Tensor self, TensorList indices, Tensor values, bool accumulate)
self: grad.clone().index_put_(indices, zeros_like(values), accumulate)
values: grad.index(indices)
- name: index_select(Tensor self, int64_t dim, Tensor index)
self: at::zeros(self.sizes(), grad.options()).index_add_(dim, index, grad)
- name: inverse(Tensor self)
self: -at::matmul(result.transpose(-2, -1), at::matmul(grad, result.transpose(-2, -1)))
- name: kthvalue(Tensor self, int64_t k, int64_t dim, bool keepdim)
self: index_select_backward(grad, dim, indices, self.sizes(), keepdim)
- name: le_(Tensor self, Scalar other)
self: zeros_like(self)
- name: le_(Tensor self, Tensor other)
self: zeros_like(self)
other: zeros_like(other)
- name: lerp(Tensor self, Tensor end, Scalar weight)
self: grad * (1 - weight.toDouble())
end: grad * weight
- name: lerp(Tensor self, Tensor end, Tensor weight)
self: grad * (1 - weight)
end: grad * weight
- name: lgamma(Tensor self)
self: grad * digamma(self)
- name: digamma(Tensor self)
self: grad * polygamma(1, self)
- name: polygamma(int64_t n, Tensor self)
self: grad * polygamma(n + 1, self)
- name: log(Tensor self)
self: grad.div(self)
- name: log10(Tensor self)
self: grad / (self * 2.3025850929940456)
- name: log1p(Tensor self)
self: log1p_backward(grad, self)
- name: log2(Tensor self)
self: grad / (self * 0.6931471805599453)
- name: logdet(Tensor self)
self: logdet_backward(grad, self, result)
- name: log_normal_(Tensor self, double mean, double std, Generator generator)
self: zeros_like(grad)
- name: logsumexp(Tensor self, IntArrayRef dim, bool keepdim)
self: logsumexp_backward(grad, self, result, dim, keepdim)
- name: lt_(Tensor self, Scalar other)
self: zeros_like(self)
- name: lt_(Tensor self, Tensor other)
self: zeros_like(self)
other: zeros_like(other)
- name: _lu_with_info(Tensor self, bool pivot, bool check_errors)
self: not_implemented("lu_with_info")
- name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots)
self: not_implemented("lu_solve")
- name: masked_fill_(Tensor self, Tensor mask, Scalar value)
self: grad.clone().masked_fill_(mask, 0)
- name: masked_fill_(Tensor self, Tensor mask, Tensor value)
self: grad.clone().masked_fill_(mask, 0)
value: at::where(mask, grad, zeros_like(grad)).sum()
- name: masked_scatter_(Tensor self, Tensor mask, Tensor source)
self: grad.clone().masked_fill_(mask, 0)
source: masked_scatter_backward(grad, mask, source.sizes())
- name: masked_select(Tensor self, Tensor mask)
# normally broadcasting is handled implicitly, but here, because we call an inplace
# function as an optimization and the LHS doesn't broadcast for inplace functions,
# we need to explicitly broadcast.
self: zeros_like(self.expand(at::infer_size(self.sizes(), mask.sizes()))).masked_scatter_(mask, grad)
- name: max(Tensor self, int64_t dim, bool keepdim)
self: index_select_backward(grad, dim, indices, self.sizes(), keepdim)
- name: max(Tensor self)
self: select_equals_backward(grad, self, result)
- name: max(Tensor self, Tensor other)
self: grad.clone().masked_fill_(self <= other, 0)
other: grad.clone().masked_fill_(self > other, 0)
- name: mean(Tensor self)
self: grad.expand(self.sizes()) / self.numel()
- name: mean(Tensor self, ScalarType dtype)
self: grad.expand(self.sizes()).to(self.scalar_type()) / self.numel()
- name: mean(Tensor self, IntArrayRef dim, bool keepdim)
self: sum_backward(grad, self.sizes(), dim, keepdim) / _safe_size(self.sizes(), dim)
- name: mean(Tensor self, IntArrayRef dim, ScalarType dtype)
self: sum_backward(grad, self.sizes(), dim, false).to(self.scalar_type()) / _safe_size(self.sizes(), dim)
- name: mean(Tensor self, IntArrayRef dim, bool keepdim, ScalarType dtype)
self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.scalar_type()) / _safe_size(self.sizes(), dim)
- name: median(Tensor self)
self: select_equals_backward(grad, self, result)
# This is in theory incorrect in the following case:
#   sorted list: [..., a, b, b, ..., b, b, c, ...] with median = b, where the
#   middle position of the list lies between two `b`s, e.g.,
#                                ^ the middle position
# The gradient exists and is essentially 0 in this case.
#
# In the case where the middle position is at the boundary of the `b` range, e.g.,
#   sorted list: [..., a, b, b, ..., b, b, c, ...]
#                                       ^ the middle position
# the backward implementation is correct in the sense that it returns the
# subgradient on one side.
- name: median(Tensor self, int64_t dim, bool keepdim)
self: index_select_backward(grad, dim, indices, self.sizes(), keepdim)
- name: min(Tensor self, int64_t dim, bool keepdim)
self: index_select_backward(grad, dim, indices, self.sizes(), keepdim)
- name: min(Tensor self)
self: select_equals_backward(grad, self, result)
- name: min(Tensor self, Tensor other)
self: grad.clone().masked_fill_(self >= other, 0)
other: grad.clone().masked_fill_(self < other, 0)
- name: mm(Tensor self, Tensor mat2)
self: mm_mat1_backward(grad, mat2, self, 1)
mat2: mm_mat2_backward(grad, self, mat2.sizes(), mat2.strides(), 1)
- name: mode(Tensor self, int64_t dim, bool keepdim)
self: index_select_backward(grad, dim, indices, self.sizes(), keepdim)
- name: mul(Tensor self, Tensor other)
self: grad * other
other: grad * self
- name: mul(Tensor self, Scalar other)
self: grad * other
- name: mv(Tensor self, Tensor vec)
self: grad.ger(vec)
vec: self.t().mv(grad)
- name: mvlgamma(Tensor self, int64_t p)
self: mvlgamma_backward(grad, self, p)
- name: native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double momentum, double eps)
input, weight, bias: native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask)
- name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_invstd, bool train, double eps, std::array<bool,3> output_mask)
input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask)
save_mean: not_implemented("native_batch_norm_backward save_mean")
save_invstd: not_implemented("native_batch_norm_backward save_invstd")
- name: ne_(Tensor self, Scalar other)
self: zeros_like(self)
- name: ne_(Tensor self, Tensor other)
self: zeros_like(self)
other: zeros_like(other)
- name: neg(Tensor self)
self: grad.neg()
- name: norm(Tensor self, Scalar p)
self: norm_backward(grad, self, p, result)
- name: norm(Tensor self, Scalar? p, IntArrayRef dim, bool keepdim)
self: norm_backward(grad, self, p, result, dim, keepdim)
- name: norm(Tensor self, Scalar? p, ScalarType dtype)
self: norm_backward(grad, self.to(grad.scalar_type()), p, result).to(self.scalar_type())
- name: norm(Tensor self, Scalar? p, IntArrayRef dim, bool keepdim, ScalarType dtype)
self: norm_backward(grad, self.to(grad.scalar_type()), p, result, dim, keepdim).to(self.scalar_type())
- name: _pdist_forward(Tensor self, double p)
self: _pdist_backward(grad, self, p, result)
- name: _pdist_backward(Tensor grad, Tensor self, double p, Tensor pdist)
grad: not_implemented("_pdist_backward")
self: not_implemented("_pdist_backward")
pdist: not_implemented("_pdist_backward")
- name: cdist(Tensor x1, Tensor x2, double p)
x1: _cdist_backward(grad, x1, x2, p, result)
x2: _cdist_backward(grad.t().contiguous(), x2, x1, p, result.t().contiguous())
- name: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, double p, Tensor cdist)
grad: not_implemented("_cdist_backward")
x1: not_implemented("_cdist_backward")
x2: not_implemented("_cdist_backward")
cdist: not_implemented("_cdist_backward")
- name: normal_(Tensor self, double mean, double std, Generator generator)
self: zeros_like(grad)
- name: normal(Tensor mean, double std, Generator generator)
mean: at::zeros(mean.sizes(), grad.options())
- name: normal(double mean, Tensor std, Generator generator)
std: at::zeros(std.sizes(), grad.options())
- name: normal(Tensor mean, Tensor std, Generator generator)
mean: at::zeros(mean.sizes(), grad.options())
std: at::zeros(std.sizes(), grad.options())
- name: orgqr(Tensor self, Tensor input2)
self: not_implemented("orgqr")
input2: not_implemented("orgqr")
- name: ormqr(Tensor self, Tensor input2, Tensor input3, bool left, bool transpose)
self: not_implemented("ormqr")
input2: not_implemented("ormqr")
input3: not_implemented("ormqr")
- name: permute(Tensor self, IntArrayRef dims)
self: permute_backwards(grad, dims)
- name: poisson(Tensor self, Generator generator)
self: zeros_like(self)
- name: pow(Tensor self, Scalar exponent)
self: pow_backward(grad, self, exponent)
- name: pow(Tensor self, Tensor exponent)
self: pow_backward_self(grad, self, exponent)
exponent: pow_backward_exponent(grad, self, exponent)
- name: pow(Scalar self, Tensor exponent)
exponent: pow_backward_exponent(grad, self, exponent)
- name: prod(Tensor self)
self: prod_backward(grad, self, result)
- name: prod(Tensor self, ScalarType dtype)
self: prod_backward(grad, self.to(grad.scalar_type()), result).to(self.scalar_type())
- name: prod(Tensor self, int64_t dim, bool keepdim)
self: prod_backward(grad, self, result, dim, keepdim)
- name: prod(Tensor self, int64_t dim, ScalarType dtype)
self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, false).to(self.scalar_type())
- name: prod(Tensor self, int64_t dim, bool keepdim, ScalarType dtype)
self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim).to(self.scalar_type())
- name: pstrf(Tensor self, bool upper, Scalar tol)
self: not_implemented("pstrf")
- name: put_(Tensor self, Tensor index, Tensor source, bool accumulate)
self: grad.clone().put_(index, zeros_like(source), accumulate)
source: grad.take(index)
- name: qr(Tensor self)
self: not_implemented("qr")
- name: random_(Tensor self, int64_t from, int64_t to, Generator generator)
self: zeros_like(grad)
- name: random_(Tensor self, int64_t to, Generator generator)
self: zeros_like(grad)
- name: random_(Tensor self, Generator generator)
self: zeros_like(grad)
- name: reciprocal(Tensor self)
self: -grad * result * result
- name: remainder(Tensor self, Scalar other)
self: grad
- name: remainder(Tensor self, Tensor other)
self: grad
- name: renorm(Tensor self, Scalar p, int64_t dim, Scalar maxnorm)
self: renorm_backward(grad, self, p, dim, maxnorm)
- name: repeat(Tensor self, IntArrayRef repeats)
self: repeat_backward(grad, self.dim(), repeats)
# DO NOT define a backward for reshape!
# reshape is special in that it sometimes returns a view, and sometimes not.
# Defining a backward will make codegen spit out the forward call as
# as_variable(baseType->reshape(self)),
# making it impossible (hard) to detect when it is actually a view.
# - name: reshape(Tensor self, IntArrayRef shape)
- name: round(Tensor self)
self: zeros_like(grad)
- name: rsqrt(Tensor self)
self: -0.5 * grad * result.pow(3)
- name: scatter_(Tensor self, int64_t dim, Tensor index, Tensor src)
self: grad.clone().scatter_(dim, index, 0)
src: grad.gather(dim, index)
- name: scatter_(Tensor self, int64_t dim, Tensor index, Scalar value)
self: grad.clone().scatter_(dim, index, 0)
- name: scatter_add_(Tensor self, int64_t dim, Tensor index, Tensor src)
self: grad
src: grad.gather(dim, index)
- name: select(Tensor self, int64_t dim, int64_t index)
self: select_backward(grad, self.sizes(), dim, index)
- name: sigmoid(Tensor self)
self: sigmoid_backward(grad, result)
- name: sign(Tensor self)
self: zeros_like(grad)
- name: sin(Tensor self)
self: grad * self.cos()
- name: sinh(Tensor self)
self: grad * self.cosh()
- name: slice(Tensor self, int64_t dim, int64_t start, int64_t end, int64_t step)
self: slice_backward(grad, self.sizes(), dim, start, end, step)
- name: slogdet(Tensor self)
self: slogdet_backward(grad, self, sign, logabsdet)
output_differentiability: [false, true]
- name: solve(Tensor self, Tensor A)
self: solve_backward_self(grad, self, A)
A: solve_backward_A(grad, self, A, solution)
- name: sort(Tensor self, int64_t dim, bool descending)
self: index_select_backward(grad, dim, indices, self.sizes(), true)
- name: split(Tensor self, int64_t split_size, int64_t dim)
self: split_backward(grads, split_size, dim, self.sizes(), self.options())
- name: split_with_sizes(Tensor self, IntArrayRef split_sizes, int64_t dim)
self: split_with_sizes_backward(grads, split_sizes, dim, self.sizes(), self.options())
- name: sqrt(Tensor self)
self: grad / (2 * result)
# Backward of squeeze/squeeze_: unsqueeze_to is given self.sizes() (and `dim`
# for the single-dim overloads) to map grad back to the input's shape.
# Formulas are bare C++ expressions, so no trailing ';'.
- name: squeeze(Tensor self)
self: unsqueeze_to(grad, self.sizes())
- name: squeeze(Tensor self, int64_t dim)
self: unsqueeze_to(grad, dim, self.sizes())
- name: squeeze_(Tensor self)
self: unsqueeze_to(grad, self.sizes())
- name: squeeze_(Tensor self, int64_t dim)
self: unsqueeze_to(grad, dim, self.sizes())
# std = sqrt(var), so d(std) = d(var) / (2 * std): scale grad by 1 / (2 * result)
# and reuse var_backward.
- name: std(Tensor self, bool unbiased)
  self: var_backward(grad / (result * 2), self, unbiased)
- name: std(Tensor self, IntArrayRef dim, bool unbiased, bool keepdim)
  self: var_backward(grad / (result * 2), self, dim, unbiased, keepdim)
# sub: out = self - alpha * other, so d/dself = grad and d/dother = -alpha * grad.
- name: sub(Tensor self, Tensor other, *, Scalar alpha)
  self: grad
  other: -grad * alpha
# A Scalar `other` receives no gradient.
- name: sub(Tensor self, Scalar other, *, Scalar alpha)
  self: grad
# rsub: out = other - alpha * self, so d/dself = -alpha * grad and d/dother = grad.
- name: rsub(Tensor self, Tensor other, *, Scalar alpha)
  self: -grad * alpha
  other: grad
- name: rsub(Tensor self, Scalar other, *, Scalar alpha)
  self: -grad * alpha
# Full reduction: expand grad back to the input's sizes.
- name: sum(Tensor self)
  self: grad.expand(self.sizes())
# The dtype overloads presumably accumulate in `dtype`; the gradient is cast
# back to the input's scalar type.
- name: sum(Tensor self, ScalarType dtype)
  self: grad.expand(self.sizes()).to(self.scalar_type())
- name: sum(Tensor self, IntArrayRef dim, bool keepdim)
  self: sum_backward(grad, self.sizes(), dim, keepdim)
# This overload takes no keepdim argument, so false is passed in keepdim's position.
- name: sum(Tensor self, IntArrayRef dim, ScalarType dtype)
  self: sum_backward(grad, self.sizes(), dim, false).to(self.scalar_type())
- name: sum(Tensor self, IntArrayRef dim, bool keepdim, ScalarType dtype)
  self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.scalar_type())
# Multi-output decompositions: `grads` holds the gradients of all outputs, and
# the backward reuses the forward outputs (U, S, V for svd; eigenvalues and
# eigenvectors_return for symeig). eigenvectors_return presumably names the
# eigenvector output, distinct from the bool `eigenvectors` argument.
- name: svd(Tensor self, bool some, bool compute_uv)
  self: svd_backward(grads, self, some, compute_uv, U, S, V)
- name: symeig(Tensor self, bool eigenvectors, bool upper)
  self: symeig_backward(grads, self, eigenvectors, upper, eigenvalues, eigenvectors_return)
# Transposing grad undoes t().
- name: t(Tensor self)
  self: grad.t()
# one_hot's input is marked non_differentiable.
- name: one_hot(Tensor self, int64_t num_classes)
  self: non_differentiable
# flip is its own inverse, so grad is flipped along the same dims.
- name: flip(Tensor self, IntArrayRef dims)
  self: grad.flip(dims)
# Undo the roll by negating every shift. shifts and dims are reversed together,
# so each shift stays paired with its dim.
- name: roll(Tensor self, IntArrayRef shifts, IntArrayRef dims)
  self: grad.roll(fmap(reverse_list(shifts), [](int64_t i){return -i;}), reverse_list(dims))
# Rotating by -k in the same plane undoes a rotation by k.
- name: rot90(Tensor self, int64_t k, IntArrayRef dims)
  self: grad.rot90(-k, dims)
# take gathers elements at `index`. The gradient scatters grad into zeros shaped
# like self with put_, whose `true` (accumulate) argument sums repeated indices.
- name: take(Tensor self, Tensor index)
  self: zeros_like(self).put_(index, grad, true)
# d/dx tan(x) = 1 + tan(x)^2; `result` is the forward output tan(x).
- name: tan(Tensor self)
  self: grad * (1 + result.pow(2))
# tanh's backward is expressed in terms of the forward output `result`.
- name: tanh(Tensor self)
  self: tanh_backward(grad, result)
# Routes grad back along `dim` to the input positions recorded in the
# `indices` output, producing a gradient with the input's sizes.
- name: topk(Tensor self, int64_t k, int64_t dim, bool largest, bool sorted)
  self: index_select_backward(grad, dim, indices, self.sizes(), true)
# trace_backward needs only the input's sizes, not its values.
- name: trace(Tensor self)
  self: trace_backward(grad, self.sizes())
# Swapping the same pair of dims again undoes the transpose.
- name: transpose(Tensor self, int64_t dim0, int64_t dim1)
  self: grad.transpose(dim0, dim1)
- name: transpose_(Tensor self, int64_t dim0, int64_t dim1)
  self: grad.transpose(dim0, dim1)
# One combined formula for both inputs. grads[0] and grads[1] are the gradients
# of the two outputs, `solution` is a forward output, and grad_input_mask
# presumably tells the helper which of self/A need gradients.
- name: triangular_solve(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular)
  self, A: triangular_solve_backward(grads[0], grads[1], self, A, solution, upper, transpose, unitriangular, grad_input_mask)
# tril/triu keep only one triangle, so grad is masked with the same `diagonal`.
- name: tril(Tensor self, int64_t diagonal)
  self: grad.tril(diagonal)
- name: triu(Tensor self, int64_t diagonal)
  self: grad.triu(diagonal)
# trunc is piecewise constant, so its gradient is defined as zero.
- name: trunc(Tensor self)
  self: zeros_like(grad)
# Layout conversions. Each backward converts grad with a dedicated helper that
# also receives the original `self` (presumably to restore self's layout; confirm).
- name: to_dense(Tensor self)
  self: to_dense_backward(grad, self)
- name: to_mkldnn(Tensor self)
  self: to_mkldnn_backward(grad, self)
# unfold_backward receives the input's sizes and the window parameters.
- name: unfold(Tensor self, int64_t dimension, int64_t size, int64_t step)
  self: unfold_backward(grad, self.sizes(), dimension, size, step)
# uniform_ overwrites self with random values, so no gradient flows to its old contents.
- name: uniform_(Tensor self, double from, double to, Generator generator)
  self: zeros_like(grad)
# Backward is not implemented for _unique.
- name: _unique(Tensor self, bool sorted, bool return_inverse)
  self: not_implemented("_unique")
# Reshape grad back to the input's sizes. reshape is used rather than view,
# presumably because grad need not be contiguous.
- name: _unsafe_view(Tensor self, IntArrayRef size)
  self: grad.reshape(self.sizes())
# Squeezing the inserted dim undoes unsqueeze.
- name: unsqueeze(Tensor self, int64_t dim)
  self: grad.squeeze(dim)
- name: unsqueeze_(Tensor self, int64_t dim)
  self: grad.squeeze(dim)
- name: var(Tensor self, bool unbiased)
  self: var_backward(grad, self, unbiased)
- name: var(Tensor self, IntArrayRef dim, bool unbiased, bool keepdim)
  self: var_backward(grad, self, dim, unbiased, keepdim)
# view only changes shape; reshape grad back to the input's sizes.
- name: view(Tensor self, IntArrayRef size)
  self: grad.reshape(self.sizes())
# condition only selects between the inputs and is non_differentiable.
# grad flows to self where condition holds and to other elsewhere.
- name: _s_where(Tensor condition, Tensor self, Tensor other)
  condition: non_differentiable
  self: where(condition, grad, zeros_like(grad))
  other: where(condition, zeros_like(grad), grad)
# _weight_norm_cuda_interface_backward does not have an explicitly defined derivative.
# When backward itself runs with grad mode enabled (e.g. create_graph=True), the
# formula therefore falls back to _weight_norm_differentiable_backward, which is
# built from differentiable ops.
# Both branches receive grad.contiguous() and reuse the forward's second output, result1.
- name: _weight_norm_cuda_interface(Tensor v, Tensor g, int64_t dim)
  v, g: "GradMode::is_enabled() ? _weight_norm_differentiable_backward(grad.contiguous(), v, g, result1, dim) : _weight_norm_cuda_interface_backward(grad.contiguous(), v, g, result1, dim)"
# zero_ overwrites self with zeros, so no gradient flows to its previous values.
- name: zero_(Tensor self)
  self: zeros_like(grad)
# Backward is not implemented for either input of sparse_mask.
- name: sparse_mask(Tensor self, SparseTensorRef mask)
  self: not_implemented("sparse_mask")
  mask: not_implemented("sparse_mask")
# Only `values` has a formula; its backward needs the indices and values' sizes.
- name: _sparse_coo_tensor_with_dims_and_tensors(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size, Tensor indices, Tensor values, TensorOptions options)
  values: sparse_constructor_values_backward(grad, indices, values.sizes())
- name: _sparse_sum(Tensor self, IntArrayRef dim)
  self: at::_sparse_sum_backward(grad, self, dim)
# The gradient w.r.t. self is grad * _standard_gamma_grad(self, result), where
# `result` is the forward output. _standard_gamma_grad has no backward of its own,
# so higher-order gradients through it are not implemented.
- name: _standard_gamma(Tensor self, Generator generator)
  self: grad * _standard_gamma_grad(self, result)
- name: _standard_gamma_grad(Tensor self, Tensor output)
  self: not_implemented("_standard_gamma_grad")
# The gradient of values() is a sparse COO tensor that reuses self's indices,
# takes grad as its values, has self's sizes, and is flagged as coalesced.
# Formulas are bare C++ expressions with no trailing ';'.
- name: values(Tensor self)
  self: at::_sparse_coo_tensor_unsafe(self.indices(), grad, self.sizes())._coalesced_(true)
# Why is _values() not differentiable?
# See NOTE [ Sparse: autograd and API ]
# (YAML booleans are written in lowercase, as in the file's other
# output_differentiability lists.)
- name: _values(Tensor self)
  output_differentiability: [false]
# NN
# One combined formula for i1, i2 and i3. grad_input_mask is forwarded to the
# helper (presumably so it can skip inputs that do not need gradients).
- name: _trilinear(Tensor i1, Tensor i2, Tensor i3, IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3, IntArrayRef sumdim, int64_t unroll_dim)
  i1, i2, i3: _trilinear_backward(grad, i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim, grad_input_mask)
# Only self has a formula; the Scalar pad `value` has none.
- name: constant_pad_nd(Tensor self, IntArrayRef pad, Scalar value)
  self: constant_pad_nd_backward(grad, pad)
# Only self has a formula here; none is listed for target or weight.
- name: binary_cross_entropy(Tensor self, Tensor target, Tensor weight, int64_t reduction)
  self: binary_cross_entropy_backward(grad, self, target, weight, reduction)
# self and target both have formulas; weight and pos_weight do not.
- name: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor weight, Tensor pos_weight, int64_t reduction)
  self: binary_cross_entropy_with_logits_backward(grad, self, target, weight, pos_weight, reduction)
  target: binary_cross_entropy_with_logits_target_backward(grad, self, target, weight, pos_weight, reduction)
# indices are marked non_differentiable. The weight gradient needs only the
# number of rows of weight (weight.size(0)), not its values.
- name: embedding(Tensor weight, Tensor indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse)
  indices: non_differentiable
  weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse)
# Double backward of the dense embedding backward, w.r.t. grad_output.
- name: embedding_dense_backward(Tensor grad_output, Tensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq)
  grad_output: embedding_dense_double_backward(grad, indices)
  indices: non_differentiable
# result1..result3 are auxiliary forward outputs reused by the backward. They are
# passed in the positions that _embedding_bag_dense_backward names offset2bag,
# bag_size and maximum_indices; presumably that is what they hold (confirm
# against the op schema).
- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, Tensor per_sample_weights)
  indices: non_differentiable
  offsets: non_differentiable
  weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse, per_sample_weights)
  per_sample_weights: _embedding_bag_per_sample_weights_backward(grad, weight, indices, offsets, result1, mode)
# Only the index/bookkeeping tensors are marked non_differentiable here; no
# formula is listed for grad or per_sample_weights.
- name: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, Tensor per_sample_weights)
  indices: non_differentiable
  offsets: non_differentiable
  offset2bag: non_differentiable
  bag_size: non_differentiable
  maximum_indices: non_differentiable
# In-place renorm: indices are non_differentiable and the backward for self is
# not implemented.
- name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
  indices: non_differentiable
  self: not_implemented("embedding_renorm")
# kl_div has formulas for both self and target.
- name: kl_div(Tensor self, Tensor target, int64_t reduction)
  self: kl_div_backward(grad, self, target, reduction)
  target: kl_div_target_backward(grad, self, target, reduction)
# l1_loss, mse_loss and multi_margin_loss have a formula for self only.
- name: l1_loss(Tensor self, Tensor target, int64_t reduction)
  self: l1_loss_backward(grad, self, target, reduction)
- name: mse_loss(Tensor self, Tensor target, int64_t reduction)
  self: mse_loss_backward(grad, self, target, reduction)
- name: multi_margin_loss(Tensor self, Tensor target, Scalar p, Scalar margin, Tensor weight, int64_t reduction)
  self: multi_margin_loss_backward(grad, self, target, p, margin, weight, reduction)
# The *_forward ops return an auxiliary output (`is_target` / `total_weight`)
# that their backward helpers reuse.
- name: multilabel_margin_loss_forward(Tensor self, Tensor target, int64_t reduction)
  self: multilabel_margin_loss_backward(grad, self, target, reduction, is_target)
- name: nll_loss_forward(Tensor self, Tensor target, Tensor weight, int64_t reduction, int64_t ignore_index)
  self: nll_loss_backward(grad, self, target, weight, reduction, ignore_index, total_weight)
- name: nll_loss2d_forward(Tensor self, Tensor target, Tensor weight, int64_t reduction, int64_t ignore_index)
  self: nll_loss2d_backward(grad, self, target, weight, reduction, ignore_index, total_weight)
# smooth_l1_loss has a formula for self only.
- name: smooth_l1_loss(Tensor self, Tensor target, int64_t reduction)
  self: smooth_l1_loss_backward(grad, self, target, reduction)
- name: soft_margin_loss(Tensor self, Tensor target, int64_t reduction)