From 9f62738e7c6be4a396aeeba209e545e03986e336 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E8=8F=9C=E8=90=9D=20=E5=8D=9C=E5=86=AC=E7=93=9C?= Date: Mon, 8 Jul 2024 11:01:12 +0800 Subject: [PATCH 1/8] Add RaspberryPi 5 CPU Overclock benchmark. (#5561) --- benchmark/README.md | 117 ++++++++++++++++++++++++++++++-------------- 1 file changed, 80 insertions(+), 37 deletions(-) diff --git a/benchmark/README.md b/benchmark/README.md index 569490f5447..006bc21507d 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -1869,10 +1869,11 @@ cooling_down = 1 vision_transformer min = 20594.51 max = 20601.53 avg = 20596.59 FastestDet min = 90.25 max = 91.00 avg = 90.64 ``` -### Raspberry Pi 5 Broadcom BCM2712, VideoCore VII Graphics Overclock to 1.1Ghz (Vulkan 1.2) + +### Raspberry Pi 5 Broadcom BCM2712 Overclock to 2.9Ghz, VideoCore VII Graphics Overclock to 1.1Ghz (Vulkan 1.2) ``` +pi@raspberrypi:~/ncnn/build/benchmark $ sudo echo "arm_freq=2900" >> /boot/firmware/config.txt pi@raspberrypi:~/ncnn/build/benchmark $ sudo echo "gpu_freq=1100" >> /boot/firmware/config.txt -pi@raspberrypi:~/ncnn/build/benchmark $ sudo echo "force_turbo=1" >> /boot/firmware/config.txt pi@raspberrypi:~/ncnn/build/benchmark $ sudo reboot pi@raspberrypi:~/ncnn/build/benchmark $ ./benchncnn 10 4 0 0 @@ -1891,41 +1892,83 @@ num_threads = 4 powersave = 0 gpu_device = 0 cooling_down = 1 - squeezenet min = 106.91 max = 106.99 avg = 106.95 - squeezenet_int8 min = 8.91 max = 9.49 avg = 9.11 - mobilenet min = 147.60 max = 147.66 avg = 147.63 - mobilenet_int8 min = 10.77 max = 36.34 avg = 14.67 - mobilenet_v2 min = 109.97 max = 110.05 avg = 110.00 - mobilenet_v3 min = 101.90 max = 102.00 avg = 101.95 - shufflenet min = 59.73 max = 60.29 avg = 59.89 - shufflenet_v2 min = 81.38 max = 81.60 avg = 81.48 - mnasnet min = 105.78 max = 105.88 avg = 105.82 - proxylessnasnet min = 108.78 max = 108.92 avg = 108.84 - efficientnet_b0 min = 168.82 max = 169.02 avg = 168.90 - efficientnetv2_b0 min = 232.37 max = 232.58 avg = 232.49 - regnety_400m min = 130.27 max = 130.41 avg = 130.34 - blazeface min = 22.14 max = 22.20 avg = 22.17 - googlenet min = 299.08 max = 299.28 avg = 299.17 - googlenet_int8 min = 29.24 max = 29.92 avg = 29.61 - resnet18 min = 304.37 max = 304.55 avg = 304.48 - resnet18_int8 min = 26.23 max = 53.80 avg = 35.61 - alexnet min = 203.85 max = 217.97 avg = 209.30 - vgg16 min = 1570.77 max = 1571.04 avg = 1570.94 - vgg16_int8 min = 129.85 max = 145.79 avg = 132.82 - resnet50 min = 753.93 max = 754.41 avg = 754.08 - resnet50_int8 min = 49.41 max = 49.84 avg = 49.64 - squeezenet_ssd min = 399.12 max = 399.55 avg = 399.30 - squeezenet_ssd_int8 min = 34.22 max = 34.89 avg = 34.54 - mobilenet_ssd min = 344.68 max = 344.90 avg = 344.79 - mobilenet_ssd_int8 min = 27.42 max = 28.16 avg = 27.74 - mobilenet_yolo min = 711.69 max = 711.76 avg = 711.72 - mobilenetv2_yolov3 min = 361.99 max = 362.11 avg = 362.05 - yolov4-tiny min = 589.25 max = 608.54 avg = 595.14 - nanodet_m min = 178.85 max = 184.93 avg = 180.18 - yolo-fastest-1.1 min = 92.28 max = 92.53 avg = 92.43 - yolo-fastestv2 min = 70.79 max = 73.38 avg = 71.19 - vision_transformer min = 18645.20 max = 18787.41 avg = 18667.17 - FastestDet min = 74.67 max = 74.77 avg = 74.71 + squeezenet min = 106.98 max = 107.05 avg = 107.02 + squeezenet_int8 min = 8.51 max = 8.83 avg = 8.65 + mobilenet min = 147.66 max = 147.71 avg = 147.68 + mobilenet_int8 min = 10.21 max = 10.54 avg = 10.37 + mobilenet_v2 min = 110.11 max = 110.23 avg = 110.18 + mobilenet_v3 min = 101.84 max = 102.03 
avg = 101.92 + shufflenet min = 59.77 max = 59.84 avg = 59.80 + shufflenet_v2 min = 81.46 max = 81.60 avg = 81.51 + mnasnet min = 105.88 max = 105.98 avg = 105.94 + proxylessnasnet min = 108.82 max = 108.89 avg = 108.86 + efficientnet_b0 min = 168.79 max = 168.93 avg = 168.87 + efficientnetv2_b0 min = 232.52 max = 232.80 avg = 232.65 + regnety_400m min = 130.33 max = 130.49 avg = 130.36 + blazeface min = 22.23 max = 22.49 avg = 22.39 + googlenet min = 299.25 max = 299.37 avg = 299.31 + googlenet_int8 min = 29.21 max = 29.97 avg = 29.58 + resnet18 min = 304.47 max = 304.64 avg = 304.58 + resnet18_int8 min = 19.31 max = 20.77 avg = 20.24 + alexnet min = 203.68 max = 203.79 avg = 203.76 + vgg16 min = 1571.91 max = 1572.22 avg = 1572.06 + vgg16_int8 min = 128.46 max = 130.89 avg = 129.96 + resnet50 min = 754.16 max = 754.33 avg = 754.26 + resnet50_int8 min = 52.65 max = 53.48 avg = 53.09 + squeezenet_ssd min = 398.22 max = 398.36 avg = 398.28 + squeezenet_ssd_int8 min = 34.26 max = 34.67 avg = 34.51 + mobilenet_ssd min = 344.81 max = 344.99 avg = 344.89 + mobilenet_ssd_int8 min = 27.59 max = 28.01 avg = 27.77 + mobilenet_yolo min = 712.53 max = 712.63 avg = 712.59 + mobilenetv2_yolov3 min = 362.81 max = 363.11 avg = 362.90 + yolov4-tiny min = 589.30 max = 589.51 avg = 589.39 + nanodet_m min = 178.83 max = 178.97 avg = 178.88 + yolo-fastest-1.1 min = 92.36 max = 92.58 avg = 92.45 + yolo-fastestv2 min = 70.68 max = 70.84 avg = 70.74 + vision_transformer min = 18615.94 max = 18648.17 avg = 18633.77 + FastestDet min = 74.59 max = 74.68 avg = 74.63 + +pi@raspberrypi:~/ncnn/build/benchmark $ ./benchncnn 10 4 0 -1 +loop_count = 10 +num_threads = 4 +powersave = 0 +gpu_device = -1 +cooling_down = 1 + squeezenet min = 7.61 max = 7.76 avg = 7.70 + squeezenet_int8 min = 7.97 max = 8.68 avg = 8.23 + mobilenet min = 9.65 max = 9.91 avg = 9.80 + mobilenet_int8 min = 10.60 max = 36.93 avg = 13.29 + mobilenet_v2 min = 12.25 max = 12.64 avg = 12.40 + mobilenet_v3 min = 8.14 max = 8.26 avg = 8.20 + shufflenet min = 3.72 max = 3.82 avg = 3.77 + shufflenet_v2 min = 2.99 max = 3.10 avg = 3.05 + mnasnet min = 7.27 max = 7.46 avg = 7.37 + proxylessnasnet min = 8.39 max = 8.55 avg = 8.48 + efficientnet_b0 min = 13.15 max = 13.59 avg = 13.39 + efficientnetv2_b0 min = 14.79 max = 15.30 avg = 14.91 + regnety_400m min = 9.49 max = 9.71 avg = 9.57 + blazeface min = 1.41 max = 1.46 avg = 1.43 + googlenet min = 28.60 max = 28.87 avg = 28.73 + googlenet_int8 min = 27.09 max = 27.77 avg = 27.47 + resnet18 min = 21.47 max = 21.88 avg = 21.65 + resnet18_int8 min = 20.07 max = 20.30 avg = 20.24 + alexnet min = 22.75 max = 23.47 avg = 23.05 + vgg16 min = 154.32 max = 158.51 avg = 157.40 + vgg16_int8 min = 127.78 max = 162.60 avg = 133.21 + resnet50 min = 49.36 max = 49.86 avg = 49.63 + resnet50_int8 min = 46.44 max = 46.89 avg = 46.74 + squeezenet_ssd min = 37.31 max = 74.95 avg = 41.30 + squeezenet_ssd_int8 min = 32.62 max = 33.63 avg = 33.09 + mobilenet_ssd min = 27.40 max = 27.99 avg = 27.68 + mobilenet_ssd_int8 min = 26.70 max = 27.71 avg = 27.23 + mobilenet_yolo min = 60.25 max = 61.10 avg = 60.67 + mobilenetv2_yolov3 min = 43.51 max = 44.29 avg = 43.87 + yolov4-tiny min = 51.63 max = 52.64 avg = 52.24 + nanodet_m min = 11.89 max = 12.06 avg = 11.97 + yolo-fastest-1.1 min = 5.63 max = 5.78 avg = 5.69 + yolo-fastestv2 min = 5.34 max = 5.48 avg = 5.40 + vision_transformer min = 481.78 max = 506.72 avg = 493.05 + FastestDet min = 4.91 max = 5.14 avg = 5.01 ``` ### Raspberry Pi Zero 2 W Broadcom BCM2710A1, Cortex-A53 (ARMv8) 
(1.0GHz x 4) From 1225793e567f5915e8ea10664cbf2493e39b73eb Mon Sep 17 00:00:00 2001 From: inspireMeNow <74866582+inspireMeNow@users.noreply.github.com> Date: Mon, 8 Jul 2024 11:07:11 +0800 Subject: [PATCH 2/8] benchmark: add Snapdragon 765G and CVITEK SG2000 (#5555) --- benchmark/README.md | 130 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) diff --git a/benchmark/README.md b/benchmark/README.md index 006bc21507d..9bfe7ae5cff 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -1046,6 +1046,98 @@ cooling_down = 1 mobilenetv2_yolov3 min = 57.49 max = 61.15 avg = 58.74 ``` +### Qualcomm SDM765G Snapdragon 765G (Kyro 1.8GHz x 6 + Kyro 2.2GHz x 2 + Adreno 620) +``` +130|bramble:/data/local/tmp $ ./benchncnn 8 4 2 -1 1 +loop_count = 8 +num_threads = 4 +powersave = 2 +gpu_device = -1 +cooling_down = 1 + squeezenet min = 9.84 max = 11.72 avg = 10.36 + squeezenet_int8 min = 10.80 max = 11.13 avg = 10.96 + mobilenet min = 14.04 max = 14.37 avg = 14.20 + mobilenet_int8 min = 13.39 max = 13.75 avg = 13.59 + mobilenet_v2 min = 13.04 max = 13.51 avg = 13.27 + mobilenet_v3 min = 11.00 max = 13.21 avg = 12.54 + shufflenet min = 11.08 max = 11.22 avg = 11.16 + shufflenet_v2 min = 8.45 max = 8.50 avg = 8.47 + mnasnet min = 14.15 max = 14.69 avg = 14.38 + proxylessnasnet min = 14.49 max = 15.07 avg = 14.83 + efficientnet_b0 min = 28.99 max = 29.53 avg = 29.24 + efficientnetv2_b0 min = 38.92 max = 39.34 avg = 39.14 + regnety_400m min = 33.46 max = 33.81 avg = 33.62 + blazeface min = 4.22 max = 4.30 avg = 4.27 + googlenet min = 35.24 max = 36.94 avg = 35.57 + googlenet_int8 min = 45.26 max = 46.46 avg = 45.78 + resnet18 min = 33.14 max = 33.75 avg = 33.31 + resnet18_int8 min = 43.26 max = 43.50 avg = 43.35 + alexnet min = 25.40 max = 26.19 avg = 25.74 + vgg16 min = 121.39 max = 122.35 avg = 121.78 + vgg16_int8 min = 243.47 max = 249.94 avg = 245.56 + resnet50 min = 67.05 max = 70.16 avg = 68.20 + resnet50_int8 min = 76.95 max = 80.23 avg = 78.18 + squeezenet_ssd min = 32.02 max = 33.27 avg = 32.51 + squeezenet_ssd_int8 min = 36.31 max = 38.35 avg = 37.09 + mobilenet_ssd min = 32.02 max = 34.55 avg = 32.99 + mobilenet_ssd_int8 min = 32.31 max = 33.92 avg = 32.77 + mobilenet_yolo min = 99.12 max = 109.81 avg = 103.00 + mobilenetv2_yolov3 min = 59.74 max = 60.95 avg = 60.21 + yolov4-tiny min = 57.83 max = 72.15 avg = 68.75 + nanodet_m min = 22.76 max = 22.97 avg = 22.85 + yolo-fastest-1.1 min = 13.58 max = 13.93 avg = 13.80 + yolo-fastestv2 min = 12.06 max = 12.27 avg = 12.15 + vision_transformer min = 1274.67 max = 1597.52 avg = 1363.14 + FastestDet min = 9.75 max = 9.86 avg = 9.81 + +130|bramble:/data/local/tmp $ ./benchncnn 8 4 2 0 1 +[0 Adreno (TM) 620] queueC=0[3] queueG=0[3] queueT=0[3] +[0 Adreno (TM) 620] bugsbn1=1 bugbilz=0 bugcopc=0 bugihfa=0 +[0 Adreno (TM) 620] fp16-p/s/u/a=1/1/0/1 int8-p/s/u/a=1/0/0/1 +[0 Adreno (TM) 620] subgroup=64 basic/vote/ballot/shuffle=1/1/1/1 +[0 Adreno (TM) 620] fp16-8x8x16/16x8x8/16x8x16/16x16x16=0/0/0/0 +loop_count = 8 +num_threads = 4 +powersave = 2 +gpu_device = 0 +cooling_down = 1 + squeezenet min = 25.06 max = 25.80 avg = 25.53 + squeezenet_int8 min = 9.75 max = 9.82 avg = 9.78 + mobilenet min = 43.43 max = 44.04 avg = 43.71 + mobilenet_int8 min = 11.12 max = 11.59 avg = 11.34 + mobilenet_v2 min = 32.14 max = 32.58 avg = 32.40 + mobilenet_v3 min = 32.75 max = 32.98 avg = 32.87 + shufflenet min = 29.29 max = 29.63 avg = 29.40 + shufflenet_v2 min = 32.43 max = 33.18 avg = 32.69 + mnasnet min = 34.58 max = 35.24 avg = 35.00 + 
proxylessnasnet min = 40.61 max = 41.40 avg = 40.98 + efficientnet_b0 min = 49.44 max = 50.46 avg = 49.95 + efficientnetv2_b0 min = 185.31 max = 187.37 avg = 186.24 + regnety_400m min = 41.43 max = 42.75 avg = 41.84 + blazeface min = 13.47 max = 14.07 avg = 13.72 + googlenet min = 78.12 max = 79.06 avg = 78.56 + googlenet_int8 min = 48.73 max = 50.13 avg = 49.20 + resnet18 min = 73.61 max = 74.05 avg = 73.75 + resnet18_int8 min = 21.87 max = 22.05 avg = 21.95 + alexnet min = 128.58 max = 129.51 avg = 128.97 + vgg16 min = 437.64 max = 439.12 avg = 438.28 + vgg16_int8 min = 232.77 max = 243.06 avg = 239.54 + resnet50 min = 187.36 max = 188.47 avg = 188.01 + resnet50_int8 min = 75.79 max = 77.33 avg = 76.64 + squeezenet_ssd min = 80.68 max = 84.50 avg = 81.93 + squeezenet_ssd_int8 min = 29.88 max = 30.77 avg = 30.30 + mobilenet_ssd min = 94.77 max = 96.46 avg = 95.79 + mobilenet_ssd_int8 min = 29.03 max = 30.07 avg = 29.53 + mobilenet_yolo min = 185.97 max = 188.11 avg = 186.59 + mobilenetv2_yolov3 min = 108.43 max = 164.75 avg = 121.55 + yolov4-tiny min = 149.38 max = 158.39 avg = 153.92 + nanodet_m min = 46.73 max = 48.85 avg = 47.73 + yolo-fastest-1.1 min = 26.32 max = 26.77 avg = 26.54 + yolo-fastestv2 min = 38.87 max = 39.31 avg = 39.13 + vision_transformer min = 3392.80 max = 3397.79 avg = 3396.09 + FastestDet min = 43.05 max = 43.81 avg = 43.45 +``` + ### Qualcomm SDM660 Snapdragon 660 (Kyro260 2.2GHz x 4 + Kyro260 1.84GHz x 4 + Adreno 512) ``` lavender:/data/local/tmp/ncnnbench $ ./benchncnn 8 8 0 -1 1 @@ -7135,6 +7227,44 @@ cooling_down = 0 FastestDet min = 199.71 max = 200.38 avg = 199.90 ``` +### CVITEK SG2000 (C906, 1 GHz x 1 + 700MHz x 1) +``` +[root@milkv-duo]~/ncnn# ./benchncnn 4 1 2 -1 0 +loop_count = 4 +num_threads = 1 +powersave = 2 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 221.53 max = 229.14 avg = 225.53 + squeezenet_int8 min = 8153.49 max = 8163.26 avg = 8160.17 + mobilenet min = 329.60 max = 338.58 avg = 335.00 + mobilenet_int8 min = 12725.12 max = 12733.70 avg = 12728.52 + mobilenet_v2 min = 253.83 max = 260.60 avg = 257.20 + mobilenet_v3 min = 205.51 max = 212.72 avg = 209.26 + shufflenet min = 358.73 max = 367.05 avg = 364.52 + shufflenet_v2 min = 238.44 max = 246.05 avg = 242.09 + mnasnet min = 254.39 max = 258.26 avg = 255.63 + proxylessnasnet min = 294.99 max = 302.80 avg = 300.65 + regnety_400m min = 407.72 max = 409.69 avg = 409.03 + blazeface min = 117.08 max = 124.26 avg = 119.00 + googlenet min = 817.28 max = 824.70 avg = 820.70 + googlenet_int8 min = 18246.97 max = 18276.23 avg = 18261.11 + resnet18 min = 610.81 max = 618.87 avg = 613.91 + resnet18_int8 min = 18772.96 max = 18808.53 avg = 18786.88 + alexnet min = 568.11 max = 577.02 avg = 570.66 + squeezenet_ssd min = 890.76 max = 896.30 avg = 893.57 + squeezenet_ssd_int8 min = 31680.48 max = 31938.09 avg = 31810.68 + mobilenet_ssd min = 746.38 max = 762.07 avg = 752.19 + mobilenet_ssd_int8 min = 41140.62 max = 41540.85 avg = 41356.70 + mobilenet_yolo min = 1744.59 max = 1755.90 avg = 1750.05 + mobilenetv2_yolov3 min = 890.20 max = 897.86 avg = 895.14 + yolov4-tiny min = 1056.03 max = 1059.44 avg = 1058.21 + nanodet_m min = 547.85 max = 554.80 avg = 549.81 + yolo-fastest-1.1 min = 290.89 max = 298.31 avg = 296.24 + yolo-fastestv2 min = 188.59 max = 196.79 avg = 190.96 + FastestDet min = 196.19 max = 205.96 avg = 200.99 +``` + ### Rockchip RK3588 (Quad Core A76 2.4GHz + Quad Core A55 1.8GHz) test in ROCK5 MODEL B From 74037b49f8bf9772c2f2c71ca40a73fa546e888c Mon Sep 17 00:00:00 2001 From: UOPiceman 
<50815088+UOPiceman@users.noreply.github.com> Date: Mon, 8 Jul 2024 12:59:13 +0800 Subject: [PATCH 3/8] Add Axera AX630C benchmark (#5559) --- benchmark/README.md | 70 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/benchmark/README.md b/benchmark/README.md index 9bfe7ae5cff..1927acf81cd 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -8790,3 +8790,73 @@ cooling_down = 0 vision_transformer min = 153.75 max = 198.81 avg = 165.58 FastestDet min = 3.01 max = 5.01 avg = 3.29 ``` + +### AXERA AX630C (Cortex-A53 1.2GHz * 2) + +``` +# ~/ncnn/build-aarch64-linux-gnu/benchmark # ./benchncnn 4 1 0 -1 0 +loop_count = 4 +num_threads = 1 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 129.78 max = 130.30 avg = 130.09 + squeezenet_int8 min = 123.08 max = 123.48 avg = 123.22 + mobilenet min = 211.46 max = 221.68 avg = 214.14 + mobilenet_int8 min = 196.00 max = 212.73 avg = 200.23 + mobilenet_v2 min = 149.15 max = 149.21 avg = 149.17 + mobilenet_v3 min = 124.70 max = 125.54 avg = 125.08 + shufflenet min = 80.75 max = 80.88 avg = 80.81 + shufflenet_v2 min = 74.30 max = 74.50 avg = 74.37 + mnasnet min = 148.87 max = 165.85 avg = 153.26 + proxylessnasnet min = 203.05 max = 213.50 avg = 205.82 + efficientnet_b0 min = 270.39 max = 280.59 avg = 273.13 + efficientnetv2_b0 min = 302.93 max = 318.07 avg = 307.30 + regnety_400m min = 187.47 max = 187.90 avg = 187.60 + blazeface min = 22.64 max = 22.78 avg = 22.72 + googlenet min = 487.36 max = 503.50 avg = 493.93 + googlenet_int8 min = 418.16 max = 434.44 avg = 426.09 + resnet18_int8 min = 290.39 max = 301.90 avg = 293.70 + resnet50_int8 min = 888.81 max = 898.34 avg = 895.92 + squeezenet_ssd min = 320.78 max = 330.33 avg = 323.54 + squeezenet_ssd_int8 min = 281.52 max = 299.11 avg = 286.89 + mobilenet_ssd min = 435.79 max = 452.66 avg = 444.19 + mobilenet_ssd_int8 min = 394.38 max = 411.09 avg = 398.65 + mobilenet_yolo min = 955.48 max = 972.38 avg = 967.52 + mobilenetv2_yolov3 min = 519.47 max = 536.58 avg = 524.25 + yolo-fastestv2 min = 73.94 max = 74.15 avg = 74.05 + FastestDet min = 81.89 max = 82.07 avg = 81.98 + +# ~/ncnn/build-aarch64-linux-gnu/benchmark # ./benchncnn 4 2 0 -1 0 +loop_count = 4 +num_threads = 2 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 75.14 max = 88.89 avg = 79.06 + squeezenet_int8 min = 70.11 max = 85.48 avg = 74.32 + mobilenet min = 112.72 max = 124.85 avg = 115.87 + mobilenet_int8 min = 100.35 max = 100.58 avg = 100.49 + mobilenet_v2 min = 85.92 max = 86.20 avg = 86.03 + mobilenet_v3 min = 73.94 max = 74.34 avg = 74.20 + shufflenet min = 53.99 max = 66.11 avg = 57.63 + shufflenet_v2 min = 47.47 max = 47.72 avg = 47.59 + mnasnet min = 85.96 max = 86.27 avg = 86.13 + proxylessnasnet min = 111.15 max = 121.84 avg = 113.92 + efficientnet_b0 min = 149.72 max = 150.00 avg = 149.85 + efficientnetv2_b0 min = 168.84 max = 170.57 avg = 169.35 + regnety_400m min = 120.42 max = 135.50 avg = 124.26 + blazeface min = 14.27 max = 14.48 avg = 14.39 + googlenet min = 263.82 max = 274.74 avg = 266.84 + googlenet_int8 min = 226.91 max = 227.36 avg = 227.23 + resnet18_int8 min = 157.66 max = 168.11 avg = 160.57 + resnet50_int8 min = 469.84 max = 484.00 avg = 476.59 + squeezenet_ssd min = 190.23 max = 204.41 avg = 193.99 + squeezenet_ssd_int8 min = 162.73 max = 174.30 avg = 165.79 + mobilenet_ssd min = 236.26 max = 251.16 avg = 240.34 + mobilenet_ssd_int8 min = 203.22 max = 212.01 avg = 206.00 + mobilenet_yolo min = 522.45 max = 537.99 avg = 529.95 + 
mobilenetv2_yolov3 min = 300.33 max = 316.59 avg = 304.89 + yolo-fastestv2 min = 50.27 max = 50.62 avg = 50.43 + FastestDet min = 53.34 max = 53.64 avg = 53.51 +``` From 02327ba96f4e0a7eec038000d8687c6a0b6a33af Mon Sep 17 00:00:00 2001 From: luxincn <118981961+luxincn@users.noreply.github.com> Date: Tue, 9 Jul 2024 11:03:05 +0800 Subject: [PATCH 4/8] add esp32 build document and ci Refs #5536 (#5567) --- .github/workflows/esp32.yml | 69 +++++++++++++++++++++++++++++++ docs/how-to-build/how-to-build.md | 29 +++++++++++++ toolchains/esp32.toolchain.cmake | 16 +++++++ 3 files changed, 114 insertions(+) create mode 100644 .github/workflows/esp32.yml create mode 100644 toolchains/esp32.toolchain.cmake diff --git a/.github/workflows/esp32.yml b/.github/workflows/esp32.yml new file mode 100644 index 00000000000..ddfb4ae2022 --- /dev/null +++ b/.github/workflows/esp32.yml @@ -0,0 +1,69 @@ +name: ESP32 +on: + push: + branches: [master] + paths: + - '.github/workflows/esp32.yml' + - 'CMakeLists.txt' + - 'cmake/**' + - 'src/*' + - 'src/layer/*' + pull_request: + branches: [master] + paths: + - '.github/workflows/esp32.yml' + - 'CMakeLists.txt' + - 'cmake/**' + - 'src/*' + - 'src/layer/*' + +concurrency: + group: esp32-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + build: + name: ESP32 + runs-on: ubuntu-20.04 + + steps: + - uses: actions/checkout@v4 + with: + submodules: true + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.8' + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y cmake ninja-build ccache + + - name: Checkout ESP-IDF + uses: actions/checkout@v4 + with: + repository: espressif/esp-idf + path: esp-idf-install + ref: release/v5.3 + + - name: Install ESP-IDF + run: | + cd esp-idf-install + git submodule update --init --recursive + ./install.sh + + - name: Set environment and build NCNN for ESP32 + run: | + source esp-idf-install/export.sh + echo "IDF_PATH=$IDF_PATH" >> $GITHUB_ENV + echo "${IDF_PATH}/tools" >> $GITHUB_PATH + echo "${IDF_PATH}/components" >> $GITHUB_PATH + mkdir -p build-esp32 && cd build-esp32 + cmake -DCMAKE_TOOLCHAIN_FILE="../toolchains/esp32.toolchain.cmake" -DCMAKE_BUILD_TYPE=Release -DNCNN_BUILD_EXAMPLES=OFF .. + make -j 4 + make install diff --git a/docs/how-to-build/how-to-build.md b/docs/how-to-build/how-to-build.md index 742f61b2192..b423834c501 100644 --- a/docs/how-to-build/how-to-build.md +++ b/docs/how-to-build/how-to-build.md @@ -28,6 +28,7 @@ git submodule update --init - [Build for QNX](#build-for-qnx) - [Build for Nintendo 3DS Homebrew Launcher](#build-for-nintendo-3ds-homebrew-launcher) - [Build for HarmonyOS with cross-compiling](#build-for-harmonyos-with-cross-compiling) +- [Build for ESP32 with cross-compiling](#build-for-esp32-with-cross-compiling) *** @@ -885,3 +886,31 @@ ${HM_SDK}/native/build-tools/cmake/bin/cmake -DOHOS_STL=c++_static -DOHOS_ARCH=a make -j$(nproc) make install ``` + +*** + +### Build for ESP32 with cross-compiling +Download esp-idf sdk +```shell +git clone https://github.com/espressif/esp-idf +cd esp-idf +git submodule update --init --recursive +``` +Install esp-idf sdk and configure the environment +```shell +sudo sh install.sh +source export.sh +``` +Note: python>=3.8, cmake>=3.24.0 + +Build ncnn library: +```shell +mkdir build-esp32 +cd build-esp32 +cmake -DCMAKE_TOOLCHAIN_FILE=../toolchains/esp32.toolchain.cmake -DCMAKE_BUILD_TYPE=Release .. 
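+# the toolchain file includes $ENV{IDF_PATH}/tools/cmake/toolchain-esp32.cmake, so run cmake in the shell where export.sh was sourced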
+make -j 4 +make install +``` +Note: Make sure to compile in esp-idf environment. + +The compiled ncnn library and headers can be put to the esp32 project to test. diff --git a/toolchains/esp32.toolchain.cmake b/toolchains/esp32.toolchain.cmake new file mode 100644 index 00000000000..34b27766efd --- /dev/null +++ b/toolchains/esp32.toolchain.cmake @@ -0,0 +1,16 @@ +set(CMAKE_SYSTEM_NAME freertos) +set(CMAKE_SYSTEM_PROCESSOR xtensa-esp32) + +include($ENV{IDF_PATH}/tools/cmake/toolchain-esp32.cmake) + +set(CMAKE_TRY_COMPILE_TARGET_TYPE STATIC_LIBRARY) + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "c++ flags") + +option(NCNN_BUILD_BENCHMARK "" OFF) \ No newline at end of file From 854678b5f394130c71735faa0e8436ea881f9ab2 Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 9 Jul 2024 11:26:17 +0800 Subject: [PATCH 5/8] pnnx convert onnx prelu gelu elu leakyrelu relu6 celu hardshrink hardsigmoid hardswish clip (#5572) --- tools/pnnx/src/CMakeLists.txt | 1 + tools/pnnx/src/pass_level2/F_celu.cpp | 21 +++ tools/pnnx/src/pass_level2/F_elu.cpp | 21 +++ tools/pnnx/src/pass_level2/F_gelu.cpp | 45 ++++++ tools/pnnx/src/pass_level2/F_hardshrink.cpp | 52 +++++++ tools/pnnx/src/pass_level2/F_hardsigmoid.cpp | 22 +++ tools/pnnx/src/pass_level2/F_hardswish.cpp | 68 ++++++++++ tools/pnnx/src/pass_level2/F_leaky_relu.cpp | 21 +++ tools/pnnx/src/pass_level2/F_prelu.cpp | 128 ++++++++++++++++++ tools/pnnx/src/pass_level2/torch_clamp.cpp | 27 ++++ tools/pnnx/src/pass_level5.cpp | 2 + .../src/pass_level5/fuse_static_prelu.cpp | 57 ++++++++ .../pnnx/src/pass_level5/fuse_static_prelu.h | 21 +++ tools/pnnx/src/pass_onnx.cpp | 16 +++ tools/pnnx/tests/onnx/CMakeLists.txt | 14 ++ tools/pnnx/tests/onnx/test_F_celu.py | 70 ++++++++++ tools/pnnx/tests/onnx/test_F_elu.py | 66 +++++++++ tools/pnnx/tests/onnx/test_F_gelu.py | 73 ++++++++++ tools/pnnx/tests/onnx/test_F_hardshrink.py | 66 +++++++++ tools/pnnx/tests/onnx/test_F_hardsigmoid.py | 83 ++++++++++++ tools/pnnx/tests/onnx/test_F_hardswish.py | 80 +++++++++++ tools/pnnx/tests/onnx/test_F_hardtanh.py | 66 +++++++++ tools/pnnx/tests/onnx/test_F_leaky_relu.py | 66 +++++++++ tools/pnnx/tests/onnx/test_F_prelu.py | 79 +++++++++++ tools/pnnx/tests/onnx/test_F_relu6.py | 71 ++++++++++ tools/pnnx/tests/onnx/test_nn_CELU.py | 73 ++++++++++ tools/pnnx/tests/onnx/test_nn_ELU.py | 69 ++++++++++ tools/pnnx/tests/onnx/test_nn_GELU.py | 68 ++++++++++ tools/pnnx/tests/onnx/test_nn_Hardshrink.py | 69 ++++++++++ tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py | 68 ++++++++++ tools/pnnx/tests/onnx/test_nn_Hardswish.py | 68 ++++++++++ tools/pnnx/tests/onnx/test_nn_Hardtanh.py | 69 ++++++++++ tools/pnnx/tests/onnx/test_nn_LeakyReLU.py | 69 ++++++++++ tools/pnnx/tests/onnx/test_nn_PReLU.py | 77 +++++++++++ tools/pnnx/tests/onnx/test_nn_ReLU6.py | 68 ++++++++++ 35 files changed, 1934 insertions(+) create mode 100644 tools/pnnx/src/pass_level5/fuse_static_prelu.cpp create mode 100644 tools/pnnx/src/pass_level5/fuse_static_prelu.h create mode 100644 tools/pnnx/tests/onnx/test_F_celu.py create mode 100644 tools/pnnx/tests/onnx/test_F_elu.py create mode 100644 tools/pnnx/tests/onnx/test_F_gelu.py create mode 100644 tools/pnnx/tests/onnx/test_F_hardshrink.py create mode 100644 tools/pnnx/tests/onnx/test_F_hardsigmoid.py create mode 100644 
tools/pnnx/tests/onnx/test_F_hardswish.py create mode 100644 tools/pnnx/tests/onnx/test_F_hardtanh.py create mode 100644 tools/pnnx/tests/onnx/test_F_leaky_relu.py create mode 100644 tools/pnnx/tests/onnx/test_F_prelu.py create mode 100644 tools/pnnx/tests/onnx/test_F_relu6.py create mode 100644 tools/pnnx/tests/onnx/test_nn_CELU.py create mode 100644 tools/pnnx/tests/onnx/test_nn_ELU.py create mode 100644 tools/pnnx/tests/onnx/test_nn_GELU.py create mode 100644 tools/pnnx/tests/onnx/test_nn_Hardshrink.py create mode 100644 tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py create mode 100644 tools/pnnx/tests/onnx/test_nn_Hardswish.py create mode 100644 tools/pnnx/tests/onnx/test_nn_Hardtanh.py create mode 100644 tools/pnnx/tests/onnx/test_nn_LeakyReLU.py create mode 100644 tools/pnnx/tests/onnx/test_nn_PReLU.py create mode 100644 tools/pnnx/tests/onnx/test_nn_ReLU6.py diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 2fc9bf37757..e2fc28da9a9 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -382,6 +382,7 @@ set(pnnx_pass_level5_SRCS pass_level5/fuse_static_instancenorm.cpp pass_level5/fuse_static_layernorm.cpp pass_level5/fuse_static_linear.cpp + pass_level5/fuse_static_prelu.cpp pass_level5/normalize_einsum_equation.cpp pass_level5/unroll_rnn_op.cpp ) diff --git a/tools/pnnx/src/pass_level2/F_celu.cpp b/tools/pnnx/src/pass_level2/F_celu.cpp index 0e13fb181d6..fac8501b2fb 100644 --- a/tools/pnnx/src/pass_level2/F_celu.cpp +++ b/tools/pnnx/src/pass_level2/F_celu.cpp @@ -38,4 +38,25 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_celu, 10) +class F_celu_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +Celu op_0 1 1 input out alpha=%alpha +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.celu"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_celu_onnx, 10) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_elu.cpp b/tools/pnnx/src/pass_level2/F_elu.cpp index 667d14fa633..d14f6208b50 100644 --- a/tools/pnnx/src/pass_level2/F_elu.cpp +++ b/tools/pnnx/src/pass_level2/F_elu.cpp @@ -40,4 +40,25 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_elu, 10) +class F_elu_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +Elu op_0 1 1 input out alpha=%alpha +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.elu"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_elu_onnx, 10) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_gelu.cpp b/tools/pnnx/src/pass_level2/F_gelu.cpp index cbb91179497..977ad86e2a4 100644 --- a/tools/pnnx/src/pass_level2/F_gelu.cpp +++ b/tools/pnnx/src/pass_level2/F_gelu.cpp @@ -305,4 +305,49 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_gelu_onnx, 9) +class F_gelu_onnx_1 : public F_gelu_onnx +{ +public: + // (x * 0.5) * (1.0 + torch.erf(x / math.sqrt(2.0))) + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +10 9 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 0p5 value=%0p5 +aten::mul op_1 2 1 input 0p5 15 +prim::Constant op_2 0 1 sqrt2 value=%sqrt2 +aten::div op_3 2 1 input sqrt2 16 +aten::erf op_4 1 1 16 17 +prim::Constant op_5 0 1 one value=%1 +aten::add op_6 2 1 17 one 22 +aten::mul op_7 2 1 15 22 out 
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_gelu_onnx_1, 9)
+
+class F_gelu_onnx_2 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+Gelu op_0 1 1 input out approximate=%approximate
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "F.gelu";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_gelu_onnx_2, 10)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/F_hardshrink.cpp b/tools/pnnx/src/pass_level2/F_hardshrink.cpp
index 1907d9471e1..656e106f499 100644
--- a/tools/pnnx/src/pass_level2/F_hardshrink.cpp
+++ b/tools/pnnx/src/pass_level2/F_hardshrink.cpp
@@ -38,4 +38,56 @@ pnnx.Output output 1 0 out
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardshrink, 10)
 
+static bool NearlyEqual(float a, float b, float epsilon)
+{
+    if (a == b)
+        return true;
+
+    float diff = (float)fabs(a - b);
+    if (diff <= epsilon)
+        return true;
+
+    // relative error
+    return diff < epsilon * std::max(fabs(a), fabs(b));
+}
+
+class F_hardshrink_1 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+9 8
+pnnx.Input input 0 1 input
+prim::Constant op_0 0 1 a value=%lambd
+aten::gt op_1 2 1 input a aa
+prim::Constant op_2 0 1 b value=%lambd2
+aten::lt op_3 2 1 input b bb
+aten::__or__ op_4 2 1 aa bb ab
+prim::Constant op_5 0 1 zero value=0
+aten::where op_6 3 1 ab input zero out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "F.hardshrink";
+    }
+
+    bool match(const std::map<std::string, Parameter>& captured_params) const
+    {
+        float lambd = captured_params.at("lambd").f;
+        float lambd2 = captured_params.at("lambd2").f;
+        return NearlyEqual(lambd, -lambd2, 0.001);
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        op->params["lambd"] = captured_params.at("lambd");
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardshrink_1, 10)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/F_hardsigmoid.cpp b/tools/pnnx/src/pass_level2/F_hardsigmoid.cpp
index 3d0678287aa..a518f7e5ef2 100644
--- a/tools/pnnx/src/pass_level2/F_hardsigmoid.cpp
+++ b/tools/pnnx/src/pass_level2/F_hardsigmoid.cpp
@@ -65,6 +65,28 @@
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardsigmoid_2, 9)
 
+class F_hardsigmoid_2_1 : public F_hardsigmoid_2
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+9 8
+pnnx.Input input 0 1 input
+prim::Constant op_0 0 1 410 value=3
+aten::add op_1 2 1 input 410 a
+prim::Constant op_2 0 1 413 value=0
+prim::Constant op_3 0 1 414 value=6
+aten::clamp op_4 3 1 a 413 414 b
+prim::Constant op_5 0 1 409 value=6
+aten::div op_6 2 1 b 409 out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardsigmoid_2_1, 9)
+
 class F_hardsigmoid_3 : public GraphRewriterPass
 {
 public:
diff --git a/tools/pnnx/src/pass_level2/F_hardswish.cpp b/tools/pnnx/src/pass_level2/F_hardswish.cpp
index e48a7e5b176..7d44efedf6b 100644
--- a/tools/pnnx/src/pass_level2/F_hardswish.cpp
+++ b/tools/pnnx/src/pass_level2/F_hardswish.cpp
@@ -146,6 +146,29 @@
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_4, 8)
 
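+// variant of the F_hardswish_4 decomposition: out = clamp(x + 3, 0, 6) / 6 * x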
+class F_hardswish_4_1 : public F_hardswish_4
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+10 9
+pnnx.Input input 0 1 input
+prim::Constant op_0 0 1 25 value=3
+aten::add op_2 2 1 input 25 a
+prim::Constant op_3 0 1 48 value=0
+prim::Constant op_4 0 1 49 value=6
+aten::clamp op_5 3 1 a 48 49 b
+prim::Constant op_6 0 1 50 value=6
+aten::div op_7 2 1 b 50 c
+aten::mul op_8 2 1 c input out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_4_1, 8)
+
 class F_hardswish_5 : public GraphRewriterPass
 {
 public:
@@ -249,4 +272,49 @@
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_onnx, 10)
 
+static bool NearlyEqual(float a, float b, float epsilon)
+{
+    if (a == b)
+        return true;
+
+    float diff = (float)fabs(a - b);
+    if (diff <= epsilon)
+        return true;
+
+    // relative error
+    return diff < epsilon * std::max(fabs(a), fabs(b));
+}
+
+class F_hardswish_onnx_1 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input input 0 1 input
+HardSigmoid op_0 1 1 input h alpha=%alpha
+aten::mul op_1 2 1 input h out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "F.hardswish";
+    }
+
+    bool match(const std::map<std::string, Parameter>& captured_params) const
+    {
+        float alpha = captured_params.at("alpha").f;
+        return NearlyEqual(alpha, 1.f / 6, 0.001);
+    }
+
+    void write(Operator* /*op*/, const std::map<std::string, Parameter>& /*captured_params*/) const
+    {
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_onnx_1, 9)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/F_leaky_relu.cpp b/tools/pnnx/src/pass_level2/F_leaky_relu.cpp
index 8fefe33c700..04c3ad45f59 100644
--- a/tools/pnnx/src/pass_level2/F_leaky_relu.cpp
+++ b/tools/pnnx/src/pass_level2/F_leaky_relu.cpp
@@ -38,4 +38,25 @@ pnnx.Output output 1 0 out
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_leaky_relu, 10)
 
+class F_leaky_relu_onnx : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+LeakyRelu op_0 1 1 input out alpha=%negative_slope
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "F.leaky_relu";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_leaky_relu_onnx, 10)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/F_prelu.cpp b/tools/pnnx/src/pass_level2/F_prelu.cpp
index 96fe7141780..47dbb6b88f4 100644
--- a/tools/pnnx/src/pass_level2/F_prelu.cpp
+++ b/tools/pnnx/src/pass_level2/F_prelu.cpp
@@ -38,4 +38,132 @@ pnnx.Output output 1 0 out
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_prelu, 10)
 
+class F_prelu_onnx : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        // clang-format off
+        // *INDENT-OFF*
+
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 weight #weight=(?)f32
+PRelu op_0 2 1 input weight out
+pnnx.Output output 1 0 out
+)PNNXIR";
+
+        // *INDENT-ON*
+        // clang-format on
+    }
+
+    const char* type_str() const
+    {
+        return "F.prelu";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_prelu_onnx, 10)
+
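+// PRelu whose slope arrives as a pnnx.Attribute initializer: keep the attribute
+// and emit F.prelu; the level-5 fuse_static_prelu pass can then fold the pair
+// into nn.PReLU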
+class F_prelu_onnx_1 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input input 0 1 input
+pnnx.Attribute slope 0 1 weight @data
+PRelu op_0 2 1 input weight out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input input 0 1 input
+pnnx.Attribute slope 0 1 weight @data=%slope.data
+F.prelu prelu 2 1 input weight out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        GraphRewriterPass::write(ops, captured_params, captured_attrs);
+
+        Operator* op_slope = ops.at("slope");
+
+        // hack slope shape
+        int num_slope = op_slope->attrs["data"].shape[0];
+        op_slope->attrs["data"].shape = {num_slope};
+
+        op_slope->outputs[0]->shape = {num_slope};
+        op_slope->outputs[0]->type = op_slope->attrs["data"].type;
+
+        Operator* op_prelu = ops.at("prelu");
+        op_prelu->inputnames = {"input", "weight"};
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_prelu_onnx_1, 10)
+
+class F_prelu_onnx_2 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        // clang-format off
+        // *INDENT-OFF*
+
+        return R"PNNXIR(7767517
+5 4
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 weight #weight=(?)f32
+Unsqueeze uqz 1 1 weight w2 axes=%axes
+PRelu op_0 2 1 input w2 out
+pnnx.Output output 1 0 out
+)PNNXIR";
+
+        // *INDENT-ON*
+        // clang-format on
+    }
+
+    const char* type_str() const
+    {
+        return "F.prelu";
+    }
+
+    bool match(const std::map<std::string, Parameter>& captured_params) const
+    {
+        if (captured_params.at("axes").type == 5)
+        {
+            // 1 2 ... N
+            const std::vector<int>& axes = captured_params.at("axes").ai;
+            for (int i = 0; i < (int)axes.size(); i++)
+            {
+                if (axes[i] != i + 1)
+                    return false;
+            }
+        }
+        else
+        {
+            int axes = captured_params.at("axes").i;
+            if (axes != 1)
+                return false;
+        }
+
+        return true;
+    }
+
+    void write(Operator* /*op*/, const std::map<std::string, Parameter>& /*captured_params*/) const
+    {
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_prelu_onnx_2, 10)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/torch_clamp.cpp b/tools/pnnx/src/pass_level2/torch_clamp.cpp
index a04d6341b76..118b6feb333 100644
--- a/tools/pnnx/src/pass_level2/torch_clamp.cpp
+++ b/tools/pnnx/src/pass_level2/torch_clamp.cpp
@@ -60,4 +60,31 @@
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_clamp_onnx, 20)
 
+class torch_clamp_onnx_1 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+aten::clamp op_0 1 1 input out max=%max
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "torch.clamp";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        op->params["min"] = 0.f;
+        op->params["max"] = captured_params.at("max");
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_clamp_onnx_1, 20)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp
index 4d483267d9d..4903f185117 100644
--- a/tools/pnnx/src/pass_level5.cpp
+++ b/tools/pnnx/src/pass_level5.cpp
@@ -59,6 +59,7 @@
 #include "pass_level5/fuse_static_instancenorm.h"
 #include "pass_level5/fuse_static_layernorm.h"
 #include "pass_level5/fuse_static_linear.h"
+#include "pass_level5/fuse_static_prelu.h"
 #include "pass_level5/normalize_einsum_equation.h"
 #include "pass_level4/dead_code_elimination.h"
 #include "pass_level4/canonicalize.h"
@@ -106,6 +107,7 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons
     fuse_static_convtranspose(g);
     fuse_static_linear(g);
     fuse_static_embedding(g);
+    fuse_static_prelu(g);
 
     fuse_conv1d_batchnorm1d(g);
     fuse_conv2d_batchnorm2d(g);
diff --git a/tools/pnnx/src/pass_level5/fuse_static_prelu.cpp b/tools/pnnx/src/pass_level5/fuse_static_prelu.cpp
new file mode 100644
index 00000000000..7bdf9a8b0bf
--- /dev/null
+++ b/tools/pnnx/src/pass_level5/fuse_static_prelu.cpp
@@ -0,0 +1,57 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "fuse_static_prelu.h"
+
+#include "pass_level2.h"
+
+#include <math.h>
+#include <string.h>
+
+namespace pnnx {
+
+class fuse_static_Fprelu_pass : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input input 0 1 input
+pnnx.Attribute op_weight 0 1 weight @data=(%num_parameters)f32
+F.prelu op_0 2 1 input weight out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+nn.PReLU prelu 1 1 input out num_parameters=%num_parameters @weight=%op_weight.data
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+void fuse_static_prelu(Graph& graph)
+{
+    fuse_static_Fprelu_pass a;
+    int opindex = 0;
+
+    pnnx_graph_rewrite(graph, &a, opindex);
+}
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5/fuse_static_prelu.h b/tools/pnnx/src/pass_level5/fuse_static_prelu.h
new file mode 100644
index 00000000000..778662ce834
--- /dev/null
+++ b/tools/pnnx/src/pass_level5/fuse_static_prelu.h
@@ -0,0 +1,21 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
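+
+// fuse a constant pnnx.Attribute slope feeding F.prelu into a single nn.PReLU module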
+ +#include "ir.h" + +namespace pnnx { + +void fuse_static_prelu(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx.cpp b/tools/pnnx/src/pass_onnx.cpp index 0fe00da4f2a..dd9194111fc 100644 --- a/tools/pnnx/src/pass_onnx.cpp +++ b/tools/pnnx/src/pass_onnx.cpp @@ -812,12 +812,28 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph) bool is_attr_weight = false; { + if (sim_op_type == "BatchNormalization" && (j == 1 || j == 2 || j == 3 || j == 4)) + is_attr_weight = true; if (sim_op_type == "Conv" && (j == 1 || j == 2)) is_attr_weight = true; if (sim_op_type == "ConvTranspose" && (j == 1 || j == 2)) is_attr_weight = true; + if (sim_op_type == "Gather" && j == 0) + is_attr_weight = true; + if (sim_op_type == "GroupNormalization" && (j == 1 || j == 2)) + is_attr_weight = true; + if (sim_op_type == "GRU" && (j == 1 || j == 2 || j == 3 || j == 5)) + is_attr_weight = true; if (sim_op_type == "InstanceNormalization" && (j == 1 || j == 2)) is_attr_weight = true; + if (sim_op_type == "LayerNormalization" && (j == 1 || j == 2)) + is_attr_weight = true; + if (sim_op_type == "LSTM" && (j == 1 || j == 2 || j == 3 || j == 5 || j == 6)) + is_attr_weight = true; + if (sim_op_type == "PRelu" && j == 1) + is_attr_weight = true; + if (sim_op_type == "RNN" && (j == 1 || j == 2 || j == 3 || j == 5)) + is_attr_weight = true; } int64_t numel = 1; diff --git a/tools/pnnx/tests/onnx/CMakeLists.txt b/tools/pnnx/tests/onnx/CMakeLists.txt index 6b9c6db7553..8ed4b5e480a 100644 --- a/tools/pnnx/tests/onnx/CMakeLists.txt +++ b/tools/pnnx/tests/onnx/CMakeLists.txt @@ -9,23 +9,30 @@ pnnx_onnx_add_test(F_avg_pool1d) pnnx_onnx_add_test(F_avg_pool2d) pnnx_onnx_add_test(F_avg_pool3d) pnnx_onnx_add_test(F_batch_norm) +pnnx_onnx_add_test(F_celu) pnnx_onnx_add_test(F_conv_transpose1d) pnnx_onnx_add_test(F_conv_transpose2d) pnnx_onnx_add_test(F_conv_transpose3d) pnnx_onnx_add_test(F_conv1d) pnnx_onnx_add_test(F_conv2d) pnnx_onnx_add_test(F_conv3d) +pnnx_onnx_add_test(F_elu) +pnnx_onnx_add_test(F_gelu) # pnnx_onnx_add_test(F_group_norm) # pnnx_onnx_add_test(F_instance_norm) pnnx_onnx_add_test(F_interpolate) pnnx_onnx_add_test(F_layer_norm) +pnnx_onnx_add_test(F_leaky_relu) pnnx_onnx_add_test(F_linear) pnnx_onnx_add_test(F_local_response_norm) pnnx_onnx_add_test(F_max_pool1d) pnnx_onnx_add_test(F_max_pool2d) pnnx_onnx_add_test(F_max_pool3d) pnnx_onnx_add_test(F_pad) +pnnx_onnx_add_test(F_prelu) pnnx_onnx_add_test(F_relu) +pnnx_onnx_add_test(F_relu6) +# pnnx_onnx_add_test(F_scaled_dot_product_attention) pnnx_onnx_add_test(F_sigmoid) pnnx_onnx_add_test(F_softmax) pnnx_onnx_add_test(F_upsample_bilinear) @@ -38,6 +45,7 @@ pnnx_onnx_add_test(nn_AvgPool3d) pnnx_onnx_add_test(nn_BatchNorm1d) pnnx_onnx_add_test(nn_BatchNorm2d) pnnx_onnx_add_test(nn_BatchNorm3d) +pnnx_onnx_add_test(nn_CELU) pnnx_onnx_add_test(nn_ConstantPad1d) pnnx_onnx_add_test(nn_ConstantPad2d) pnnx_onnx_add_test(nn_ConstantPad3d) @@ -47,21 +55,27 @@ pnnx_onnx_add_test(nn_Conv3d) pnnx_onnx_add_test(nn_ConvTranspose1d) pnnx_onnx_add_test(nn_ConvTranspose2d) pnnx_onnx_add_test(nn_ConvTranspose3d) +pnnx_onnx_add_test(nn_ELU) +pnnx_onnx_add_test(nn_GELU) pnnx_onnx_add_test(nn_GroupNorm) pnnx_onnx_add_test(nn_GRU) pnnx_onnx_add_test(nn_InstanceNorm1d) pnnx_onnx_add_test(nn_InstanceNorm2d) pnnx_onnx_add_test(nn_InstanceNorm3d) pnnx_onnx_add_test(nn_LayerNorm) +pnnx_onnx_add_test(nn_LeakyReLU) pnnx_onnx_add_test(nn_Linear) pnnx_onnx_add_test(nn_LocalResponseNorm) pnnx_onnx_add_test(nn_LSTM) pnnx_onnx_add_test(nn_MaxPool1d) 
pnnx_onnx_add_test(nn_MaxPool2d) pnnx_onnx_add_test(nn_MaxPool3d) +# pnnx_onnx_add_test(nn_MultiheadAttention) +pnnx_onnx_add_test(nn_PReLU) pnnx_onnx_add_test(nn_ReflectionPad1d) pnnx_onnx_add_test(nn_ReflectionPad2d) pnnx_onnx_add_test(nn_ReLU) +pnnx_onnx_add_test(nn_ReLU6) pnnx_onnx_add_test(nn_ReplicationPad1d) pnnx_onnx_add_test(nn_ReplicationPad2d) pnnx_onnx_add_test(nn_ReplicationPad3d) diff --git a/tools/pnnx/tests/onnx/test_F_celu.py b/tools/pnnx/tests/onnx/test_F_celu.py new file mode 100644 index 00000000000..dd474e030d9 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_celu.py @@ -0,0 +1,70 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.celu(x) + y = F.celu(y, 0.8) + z = F.celu(z, 0.5) + w = F.celu(w, 2) + return x, y, z, w + +def test(): + if version.parse(torch.__version__) < version.parse('1.12'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_celu.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_celu.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_celu_pnnx + b = test_F_celu_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_elu.py b/tools/pnnx/tests/onnx/test_F_elu.py new file mode 100644 index 00000000000..5126ea81a33 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_elu.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
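+
+# same flow as the other onnx tests: run the torch model, export it to onnx,
+# convert with pnnx, then compare pnnx inference output against the torch results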
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.elu(x) + y = F.elu(y, 1.2) + z = F.elu(z, -0.6) + w = F.elu(w, 0) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_elu.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_elu.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_elu_pnnx + b = test_F_elu_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_gelu.py b/tools/pnnx/tests/onnx/test_F_gelu.py new file mode 100644 index 00000000000..f93d4848f8d --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_gelu.py @@ -0,0 +1,73 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +def gelu_forward_0(x): + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + +def gelu_forward_1(x): + return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.gelu(x) + y = F.gelu(y) + z = gelu_forward_0(z) + w = gelu_forward_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_gelu.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_gelu.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_gelu_pnnx + b = test_F_gelu_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-3, 1e-3): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_hardshrink.py b/tools/pnnx/tests/onnx/test_F_hardshrink.py new file mode 100644 index 00000000000..52d44ff6c15 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_hardshrink.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. 
+# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.hardshrink(x) + y = F.hardshrink(y, 0.1) + z = F.hardshrink(z, 0.22) + w = F.hardshrink(w, 0) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_hardshrink.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_hardshrink.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_hardshrink_pnnx + b = test_F_hardshrink_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_hardsigmoid.py b/tools/pnnx/tests/onnx/test_F_hardsigmoid.py new file mode 100644 index 00000000000..53c3adaefa5 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_hardsigmoid.py @@ -0,0 +1,83 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def hardsigmoid_forward_0(x): + return F.relu6(x + 3., True) / 6. + +def hardsigmoid_forward_1(x): + return x.add_(3.).clamp_(0., 6.).div_(6.) 
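+
+# both helpers spell hardsigmoid as clamp(x + 3, 0, 6) / 6, the graph shape that
+# the F_hardsigmoid onnx patterns in this patch fold back into F.hardsigmoid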
+ +class h_sigmoid(nn.Module): + def __init__(self, inplace=True): + super(h_sigmoid, self).__init__() + self.relu = nn.ReLU6(inplace=inplace) + + def forward(self, x): + return self.relu(x + 3) / 6 + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.h_sigmoid = h_sigmoid(); + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.hardsigmoid(x) + y = F.hardsigmoid(y) + z = self.h_sigmoid(z) + w = hardsigmoid_forward_0(w) + w = hardsigmoid_forward_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_hardsigmoid.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_hardsigmoid.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_hardsigmoid_pnnx + b = test_F_hardsigmoid_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_hardswish.py b/tools/pnnx/tests/onnx/test_F_hardswish.py new file mode 100644 index 00000000000..cf671d4c611 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_hardswish.py @@ -0,0 +1,80 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def hardswish_forward_0(x): + return x * F.hardsigmoid(x) + +def hardswish_forward_1(x): + return x * F.hardtanh(x + 3, 0., 6.) / 6. + +def hardswish_forward_2(x): + out = F.relu6(x + 3., True) / 6. 
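+    # relu6(x + 3) / 6 == hardsigmoid(x), so out * x below reproduces hardswish(x)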
+ return out * x + +def hardswish_forward_3(x): + return x * F.relu6(x + 3, inplace=True) / 6 + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.hardswish(x) + y = hardswish_forward_0(y) + z = hardswish_forward_1(z) + w = hardswish_forward_2(w) + w = hardswish_forward_3(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_hardswish.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_hardswish.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_hardswish_pnnx + b = test_F_hardswish_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_hardtanh.py b/tools/pnnx/tests/onnx/test_F_hardtanh.py new file mode 100644 index 00000000000..5043703d907 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_hardtanh.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.hardtanh(x) + y = F.hardtanh(y, -1, 1) + z = F.hardtanh(z, -0.1, 0.1) + w = F.hardtanh(w, 0.1, 0.3) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_hardtanh.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_hardtanh.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_hardtanh_pnnx + b = test_F_hardtanh_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_leaky_relu.py b/tools/pnnx/tests/onnx/test_F_leaky_relu.py new file mode 100644 index 00000000000..bdffaa40ef6 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_leaky_relu.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. 
+# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.leaky_relu(x) + y = F.leaky_relu(y, 0.1) + z = F.leaky_relu(z, -0.22) + w = F.leaky_relu(w, 0) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_leaky_relu.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_leaky_relu.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_leaky_relu_pnnx + b = test_F_leaky_relu_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_prelu.py b/tools/pnnx/tests/onnx/test_F_prelu.py new file mode 100644 index 00000000000..9e0a9103763 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_prelu.py @@ -0,0 +1,79 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
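The leaky_relu test above deliberately includes a negative slope (-0.22): leaky_relu(x, s) is max(x, 0) + s * min(x, 0), so nothing requires s to be positive. A minimal sketch of that identity:

```
import torch
import torch.nn.functional as F

x = torch.randn(64)
for slope in (0.01, 0.1, -0.22, 0.0):
    out = F.leaky_relu(x, slope)
    # keep positives as-is, scale negatives by the slope
    assert torch.allclose(out, x.clamp(min=0) + slope * x.clamp(max=0))
```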
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.w4 = nn.Parameter(torch.rand(16)) + self.w5 = nn.Parameter(torch.rand(2)) + self.w6 = nn.Parameter(torch.rand(3)) + self.w7 = nn.Parameter(torch.rand(1)) + + def forward(self, x, y, z, w, w0, w1, w2, w3): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.prelu(x, w0) + x = F.prelu(x, self.w4) + y = F.prelu(y, w1) + y = F.prelu(y, self.w5) + z = F.prelu(z, w2) + z = F.prelu(z, self.w6) + w = F.prelu(w, w3) + w = F.prelu(w, self.w7) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + w0 = torch.rand(16) + w1 = torch.rand(2) + w2 = torch.rand(3) + w3 = torch.rand(1) + + a = net(x, y, z, w, w0, w1, w2, w3) + + # export onnx + torch.onnx.export(net, (x, y, z, w, w0, w1, w2, w3), "test_F_prelu.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_prelu.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11],[16],[2],[3],[1]") + + # pnnx inference + import test_F_prelu_pnnx + b = test_F_prelu_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_relu6.py b/tools/pnnx/tests/onnx/test_F_relu6.py new file mode 100644 index 00000000000..4a64fafc20b --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_relu6.py @@ -0,0 +1,71 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
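F.prelu broadcasts its weight along dimension 1, which is why the prelu test above pairs weight sizes 16, 2, 3 and 1 with inputs whose dim-1 sizes match (a single element always broadcasts). A small sketch of the rule, assuming only torch:

```
import torch
import torch.nn.functional as F

y = torch.randn(12, 2, 16)
w_channel = torch.rand(2)   # one slope per channel (dim 1)
w_shared = torch.rand(1)    # a single shared slope

# prelu(x) = max(x, 0) + w * min(x, 0), with w broadcast over dim 1
expected = y.clamp(min=0) + w_channel.view(1, 2, 1) * y.clamp(max=0)
assert torch.allclose(F.prelu(y, w_channel), expected)
assert F.prelu(y, w_shared).shape == y.shape
```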
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.relu6(x) + y = F.relu6(y) + z = F.relu6(z) + w = F.relu6(w) + return x, y, z, w + +def test(): + # torch-1.9 failed to export onnx for relu6 + if version.parse(torch.__version__) >= version.parse('1.9') and version.parse(torch.__version__) < version.parse('1.10'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_relu6.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_relu6.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_relu6_pnnx + b = test_F_relu6_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_CELU.py b/tools/pnnx/tests/onnx/test_nn_CELU.py new file mode 100644 index 00000000000..d1c2463298f --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_CELU.py @@ -0,0 +1,73 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
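relu6 is a plain clamp to [0, 6] (equivalently hardtanh(x, 0, 6)); the early return above skips the whole test on torch 1.9.x, where relu6 failed to export to ONNX. The equivalence as a sketch:

```
import torch
import torch.nn.functional as F

x = torch.randn(64) * 4
assert torch.allclose(F.relu6(x), x.clamp(0., 6.))
assert torch.allclose(F.relu6(x), F.hardtanh(x, 0., 6.))
```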
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.CELU() + self.act_1 = nn.CELU(alpha=2.0) + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + if version.parse(torch.__version__) < version.parse('1.12'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_CELU.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_CELU.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_CELU_pnnx + b = test_nn_CELU_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_ELU.py b/tools/pnnx/tests/onnx/test_nn_ELU.py new file mode 100644 index 00000000000..c55d829200b --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_ELU.py @@ -0,0 +1,69 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.ELU() + self.act_1 = nn.ELU(alpha=1.3) + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_ELU.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_ELU.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_ELU_pnnx + b = test_nn_ELU_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_GELU.py b/tools/pnnx/tests/onnx/test_nn_GELU.py new file mode 100644 index 00000000000..ebf85e08e44 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_GELU.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. 
+# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.GELU() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_GELU.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_GELU.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_GELU_pnnx + b = test_nn_GELU_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-3, 1e-3): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_Hardshrink.py b/tools/pnnx/tests/onnx/test_nn_Hardshrink.py new file mode 100644 index 00000000000..c24e48f5bcb --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_Hardshrink.py @@ -0,0 +1,69 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
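The GELU test above compares with torch.allclose at 1e-3 rather than torch.equal, leaving headroom for the exact erf form and the tanh approximation to diverge somewhere along the export path. The two forms agree only to roughly that tolerance:

```
import math
import torch
import torch.nn.functional as F

x = torch.randn(64)
exact = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
tanh_form = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

assert torch.allclose(F.gelu(x), exact, atol=1e-5)   # default gelu is the erf form
assert torch.allclose(exact, tanh_form, atol=1e-3)   # approximation error is small but nonzero
```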
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Hardshrink() + self.act_1 = nn.Hardshrink(lambd=0.3) + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_Hardshrink.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_Hardshrink.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Hardshrink_pnnx + b = test_nn_Hardshrink_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py b/tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py new file mode 100644 index 00000000000..35d02bef202 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Hardsigmoid() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_Hardsigmoid.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_Hardsigmoid.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Hardsigmoid_pnnx + b = test_nn_Hardsigmoid_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_Hardswish.py b/tools/pnnx/tests/onnx/test_nn_Hardswish.py new file mode 100644 index 00000000000..2f313507ffc --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_Hardswish.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. 
All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Hardswish() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_Hardswish.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_Hardswish.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Hardswish_pnnx + b = test_nn_Hardswish_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_Hardtanh.py b/tools/pnnx/tests/onnx/test_nn_Hardtanh.py new file mode 100644 index 00000000000..b376bf748db --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_Hardtanh.py @@ -0,0 +1,69 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
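nn.Hardswish mirrors the functional identity exercised earlier: hardswish(x) = x * hardsigmoid(x) = x * relu6(x + 3) / 6, the same add/clamp/div/mul chain that the F_hardswish_onnx_2 rewriter added later in this series matches. A standalone check:

```
import torch
import torch.nn.functional as F

x = torch.randn(64)
ref = F.hardswish(x)
assert torch.allclose(ref, x * F.hardsigmoid(x))
assert torch.allclose(ref, x * (x + 3.).clamp(0., 6.) / 6.)
```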
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Hardtanh() + self.act_1 = nn.Hardtanh(-0.2, 0.2) + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_Hardtanh.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_Hardtanh.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Hardtanh_pnnx + b = test_nn_Hardtanh_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_LeakyReLU.py b/tools/pnnx/tests/onnx/test_nn_LeakyReLU.py new file mode 100644 index 00000000000..fc096290461 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_LeakyReLU.py @@ -0,0 +1,69 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.LeakyReLU() + self.act_1 = nn.LeakyReLU(negative_slope=-0.24) + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_LeakyReLU.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_LeakyReLU.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_LeakyReLU_pnnx + b = test_nn_LeakyReLU_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_PReLU.py b/tools/pnnx/tests/onnx/test_nn_PReLU.py new file mode 100644 index 00000000000..306a83140a6 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_PReLU.py @@ -0,0 +1,77 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. 
All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.prelu_0 = nn.PReLU(num_parameters=12) + self.prelu_1 = nn.PReLU(num_parameters=1, init=0.12) + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + + x = self.prelu_0(x) + x = self.prelu_1(x) + + y = self.prelu_0(y) + y = self.prelu_1(y) + + z = self.prelu_0(z) + z = self.prelu_1(z) + + w = self.prelu_0(w) + w = self.prelu_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_PReLU.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_PReLU.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_PReLU_pnnx + b = test_nn_PReLU_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_ReLU6.py b/tools/pnnx/tests/onnx/test_nn_ReLU6.py new file mode 100644 index 00000000000..bdfc654e3eb --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_ReLU6.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
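nn.PReLU keeps its slopes in a learnable weight of num_parameters elements, so prelu_0 above requires dim-1 of every input to be 12, while the one-parameter prelu_1 broadcasts over any shape; that is what lets the two modules chain across all four input ranks. A minimal sketch:

```
import torch
import torch.nn as nn

m12 = nn.PReLU(num_parameters=12)
m1 = nn.PReLU(num_parameters=1, init=0.12)

for shape in [(1, 12), (1, 12, 64), (1, 12, 24, 64), (1, 12, 24, 32, 64)]:
    x = torch.rand(shape) * 2 - 1
    assert m1(m12(x)).shape == x.shape

# weight shapes: torch.Size([12]) and torch.Size([1])
```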
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.ReLU6() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_ReLU6.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_ReLU6.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_ReLU6_pnnx + b = test_nn_ReLU6_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) From c59885aeac6cec0dbfa010efc0b5c25bed5208b7 Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 11 Jul 2024 15:53:27 +0800 Subject: [PATCH 6/8] pnnx convert onnx multiheadattention (#5575) * pnnx convert onnx multiheadattention * onnx reducemean reducesum * reducemax reducemin reduceprod * mask buggy torch * avoid shadow output --- tools/pnnx/src/pass_level2.cpp | 25 +- tools/pnnx/src/pass_level2/F_hardswish.cpp | 26 + tools/pnnx/src/pass_level2/torch_max.cpp | 49 ++ tools/pnnx/src/pass_level2/torch_mean.cpp | 109 +---- tools/pnnx/src/pass_level2/torch_min.cpp | 49 ++ tools/pnnx/src/pass_level2/torch_prod.cpp | 49 ++ tools/pnnx/src/pass_level2/torch_sum.cpp | 49 ++ .../pass_level5/fuse_multiheadattention.cpp | 454 ++++++++++++++++++ .../pass_onnx/fuse_constant_as_attribute.cpp | 4 + tools/pnnx/tests/onnx/CMakeLists.txt | 10 +- tools/pnnx/tests/onnx/test_F_hardshrink.py | 4 + tools/pnnx/tests/onnx/test_F_hardsigmoid.py | 4 + tools/pnnx/tests/onnx/test_F_hardswish.py | 4 + tools/pnnx/tests/onnx/test_nn_Hardshrink.py | 4 + tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py | 4 + .../tests/onnx/test_nn_MultiheadAttention.py | 138 ++++++ 16 files changed, 893 insertions(+), 89 deletions(-) create mode 100644 tools/pnnx/tests/onnx/test_nn_MultiheadAttention.py diff --git a/tools/pnnx/src/pass_level2.cpp b/tools/pnnx/src/pass_level2.cpp index 107613d3dea..bc7e51b8d5d 100644 --- a/tools/pnnx/src/pass_level2.cpp +++ b/tools/pnnx/src/pass_level2.cpp @@ -738,7 +738,7 @@ static bool match_operator(const Operator* a, const Operator* b, std::map& matched_operators, std::map& matched_inputs, std::map& captured_params, std::map& captured_attrs) +static bool match(const Operator* anchor, const Operator* pattern, std::map& matched_operators, std::map& matched_inputs, std::map& matched_outputs, std::map& captured_params, std::map& captured_attrs) { if (!match_operator(anchor, pattern, captured_params, captured_attrs)) return false; @@ -746,7 +746,17 @@ static bool match(const Operator* anchor, const Operator* pattern, std::mapoutputs.size(); i++) { if (pattern->outputs[i]->consumers.size() == 1 && pattern->outputs[i]->consumers[0]->type == "pnnx.Output") + { + if (matched_outputs.find(pattern->outputs[i]->name) == matched_outputs.end()) + { + matched_outputs[pattern->outputs[i]->name] = anchor->outputs[i]; + } + else if (matched_outputs[pattern->outputs[i]->name] != anchor->outputs[i]) + { + return false; + } continue; + } if (anchor->outputs[i]->consumers.size() != 
pattern->outputs[i]->consumers.size()) return false; @@ -773,7 +783,7 @@ static bool match(const Operator* anchor, const Operator* pattern, std::map matched_operators2; std::map matched_inputs2; + std::map matched_outputs2; std::map captured_params2; std::map captured_attrs2; - if (!match(anchor, pattern2, matched_operators2, matched_inputs2, captured_params2, captured_attrs2)) + if (!match(anchor, pattern2, matched_operators2, matched_inputs2, matched_outputs2, captured_params2, captured_attrs2)) continue; bool submatch_matched = true; @@ -872,6 +883,13 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde matched_inputs[x.first] = x.second; } } + for (auto x : matched_outputs2) + { + if (matched_outputs.find(x.first) == matched_outputs.end()) + { + matched_outputs[x.first] = x.second; + } + } for (auto x : captured_params2) { captured_params[x.first] = x.second; @@ -882,7 +900,6 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde } // match ! - matched_outputs[pattern->inputs[i]->name] = anchor->outputs[i]; break; } diff --git a/tools/pnnx/src/pass_level2/F_hardswish.cpp b/tools/pnnx/src/pass_level2/F_hardswish.cpp index 7d44efedf6b..caa724f55a7 100644 --- a/tools/pnnx/src/pass_level2/F_hardswish.cpp +++ b/tools/pnnx/src/pass_level2/F_hardswish.cpp @@ -317,4 +317,30 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_onnx_1, 9) +class F_hardswish_onnx_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 20 value=3 +aten::add op_1 2 1 input 20 8 +aten::clamp op_2 1 1 8 9 max=6 min=0 +prim::Constant op_3 0 1 23 value=6 +aten::div op_4 2 1 9 23 10 +aten::mul op_5 2 1 input 10 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardswish"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_onnx_2, 9) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_max.cpp b/tools/pnnx/src/pass_level2/torch_max.cpp index 3448b5b939f..68479b85d5b 100644 --- a/tools/pnnx/src/pass_level2/torch_max.cpp +++ b/tools/pnnx/src/pass_level2/torch_max.cpp @@ -60,4 +60,53 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_1, 20) +class torch_max_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +ReduceMax op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.max"; + } + + void write(Operator* op, const std::map& captured_params) const + { + if (captured_params.find("op_0.axes") != captured_params.end()) + { + op->params["dim"] = captured_params.at("op_0.axes"); + } + else + { + // reduce all + const int input_rank = (int)op->inputs[0]->shape.size(); + std::vector dim(input_rank); + for (int i = 0; i < input_rank; i++) + { + dim[i] = i; + } + op->params["dim"] = dim; + } + + if (captured_params.find("op_0.keepdims") != captured_params.end()) + { + op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? 
true : false; + } + else + { + op->params["keepdim"] = true; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_onnx, 20) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_mean.cpp b/tools/pnnx/src/pass_level2/torch_mean.cpp index 18944fac345..39f4423243d 100644 --- a/tools/pnnx/src/pass_level2/torch_mean.cpp +++ b/tools/pnnx/src/pass_level2/torch_mean.cpp @@ -107,7 +107,7 @@ class torch_mean_onnx : public GraphRewriterPass return R"PNNXIR(7767517 3 2 pnnx.Input input 0 1 input -ReduceMean op_0 1 1 input out axes=%axes keepdims=%keepdims +ReduceMean op_0 1 1 input out %*=%* pnnx.Output output 1 0 out )PNNXIR"; } @@ -119,92 +119,33 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - op->params["dim"] = captured_params.at("axes"); - op->params["keepdim"] = captured_params.at("keepdims").i ? true : false; - } -}; - -REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx, 20) - -class torch_mean_onnx_1 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -3 2 -pnnx.Input input 0 1 input -ReduceMean op_0 1 1 input out axes=%axes -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "torch.mean"; - } - - void write(Operator* op, const std::map& captured_params) const - { - op->params["dim"] = captured_params.at("axes"); - op->params["keepdim"] = true; - } -}; - -REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx_1, 20) - -class torch_mean_onnx_2 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 dim -ReduceMean op_0 2 1 input dim out keepdims=%keepdims noop_with_empty_axes=0 -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "torch.mean"; - } - - void write(Operator* op, const std::map& captured_params) const - { - op->params["keepdim"] = captured_params.at("keepdims").i ? true : false; - } -}; - -REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx_2, 20) - -class torch_mean_onnx_3 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -3 2 -pnnx.Input input 0 1 input -ReduceMean op_0 1 1 input out axes=%axes keepdims=%keepdims noop_with_empty_axes=0 -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "torch.mean"; - } + if (captured_params.find("op_0.axes") != captured_params.end()) + { + op->params["dim"] = captured_params.at("op_0.axes"); + } + else + { + // reduce all + const int input_rank = (int)op->inputs[0]->shape.size(); + std::vector dim(input_rank); + for (int i = 0; i < input_rank; i++) + { + dim[i] = i; + } + op->params["dim"] = dim; + } - void write(Operator* op, const std::map& captured_params) const - { - op->params["dim"] = captured_params.at("axes"); - op->params["keepdim"] = captured_params.at("keepdims").i ? true : false; + if (captured_params.find("op_0.keepdims") != captured_params.end()) + { + op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? 
true : false; + } + else + { + op->params["keepdim"] = true; + } } }; -REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx_3, 20) +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx, 20) } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_min.cpp b/tools/pnnx/src/pass_level2/torch_min.cpp index 119b442e11d..c5e48bbc64b 100644 --- a/tools/pnnx/src/pass_level2/torch_min.cpp +++ b/tools/pnnx/src/pass_level2/torch_min.cpp @@ -60,4 +60,53 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_1, 20) +class torch_min_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +ReduceMin op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.min"; + } + + void write(Operator* op, const std::map& captured_params) const + { + if (captured_params.find("op_0.axes") != captured_params.end()) + { + op->params["dim"] = captured_params.at("op_0.axes"); + } + else + { + // reduce all + const int input_rank = (int)op->inputs[0]->shape.size(); + std::vector dim(input_rank); + for (int i = 0; i < input_rank; i++) + { + dim[i] = i; + } + op->params["dim"] = dim; + } + + if (captured_params.find("op_0.keepdims") != captured_params.end()) + { + op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false; + } + else + { + op->params["keepdim"] = true; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_onnx, 20) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_prod.cpp b/tools/pnnx/src/pass_level2/torch_prod.cpp index bd3e49b8cb1..7f15c2ba88a 100644 --- a/tools/pnnx/src/pass_level2/torch_prod.cpp +++ b/tools/pnnx/src/pass_level2/torch_prod.cpp @@ -40,4 +40,53 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_prod, 20) +class torch_prod_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +ReduceProd op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.prod"; + } + + void write(Operator* op, const std::map& captured_params) const + { + if (captured_params.find("op_0.axes") != captured_params.end()) + { + op->params["dim"] = captured_params.at("op_0.axes"); + } + else + { + // reduce all + const int input_rank = (int)op->inputs[0]->shape.size(); + std::vector dim(input_rank); + for (int i = 0; i < input_rank; i++) + { + dim[i] = i; + } + op->params["dim"] = dim; + } + + if (captured_params.find("op_0.keepdims") != captured_params.end()) + { + op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? 
true : false; + } + else + { + op->params["keepdim"] = true; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_prod_onnx, 20) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_sum.cpp b/tools/pnnx/src/pass_level2/torch_sum.cpp index 730ffcce3b8..51803d5d01e 100644 --- a/tools/pnnx/src/pass_level2/torch_sum.cpp +++ b/tools/pnnx/src/pass_level2/torch_sum.cpp @@ -62,4 +62,53 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_sum_1, 20) +class torch_sum_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +ReduceSum op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.sum"; + } + + void write(Operator* op, const std::map& captured_params) const + { + if (captured_params.find("op_0.axes") != captured_params.end()) + { + op->params["dim"] = captured_params.at("op_0.axes"); + } + else + { + // reduce all + const int input_rank = (int)op->inputs[0]->shape.size(); + std::vector dim(input_rank); + for (int i = 0; i < input_rank; i++) + { + dim[i] = i; + } + op->params["dim"] = dim; + } + + if (captured_params.find("op_0.keepdims") != captured_params.end()) + { + op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false; + } + else + { + op->params["keepdim"] = true; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_sum_onnx, 20) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp b/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp index 55661366c59..2a9f3b837b1 100644 --- a/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp +++ b/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp @@ -1574,6 +1574,444 @@ pnnx.Output output 1 0 out } }; +class fuse_multiheadattention_pass_onnx : public fuse_multiheadattention_pass_qkv +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +23 22 +pnnx.Input input_q 0 1 query +pnnx.Input input_k 0 1 key +pnnx.Input input_v 0 1 value +nn.Linear op_0 1 1 query 10 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 key 11 bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 value 12 bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 10 13 shape=(%qsize,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 11 15 shape=(%kvsize,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 12 16 shape=(%kvsize,%num_heads,%feat_per_head) +Tensor.permute op_6 1 1 13 14 dims=(1,0,2) +Tensor.permute op_7 1 1 15 19 dims=(1,2,0) +Tensor.permute op_8 1 1 16 17 dims=(1,0,2) +pnnx.Expression op_9 1 1 14 18 expr=mul(@0,%inv_sqrt_embed_dim_per_head) +torch.matmul op_10 2 1 18 19 20 +F.softmax softmax 1 1 20 21 dim=%softmax_dim +torch.matmul op_12 2 1 21 17 22 +Tensor.permute op_13 1 1 22 23 dims=(1,0,2) +Tensor.reshape op_14 1 1 23 24 shape=(%qsize,%embed_dim) +nn.Linear out_proj 1 1 24 25 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_16 1 1 25 out shape=(%qsize,%batch,%embed_dim) +Tensor.reshape op_17 1 1 21 27 shape=(%batch,%num_heads,%qsize,%kvsize) +torch.mean op_18 1 1 27 outweight dim=(1) keepdim=False +pnnx.Output output 2 0 out outweight +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +5 5 +pnnx.Input input_0 0 1 query +pnnx.Input input_1 0 1 key +pnnx.Input 
input_2 0 1 value +nn.MultiheadAttention attention 3 2 query key value out outweight embed_dim=%embed_dim kdim=%kdim vdim=%vdim num_heads=%num_heads batch_first=False add_zero_attn=False add_bias_kv=False +pnnx.Output output 2 0 out outweight +)PNNXIR"; + } +}; + +class fuse_multiheadattention_pass_onnx_1 : public fuse_multiheadattention_pass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +26 25 +pnnx.Input input_q 0 1 input +nn.Linear op_0 1 1 input 14 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight +Tensor.reshape op_1 1 1 14 15 shape=(%batch,%size,1,3,%embed_dim) +Tensor.permute op_2 1 1 15 16 dims=(3,1,2,0,4) +torch.squeeze op_3 1 1 16 17 dim=3 +torch.unbind op_4 1 3 17 18 19 20 dim=0 +Tensor.reshape op_5 1 1 18 21 shape=(%size,%num_heads,%feat_per_head) +Tensor.reshape op_6 1 1 19 23 shape=(%size,%num_heads,%feat_per_head) +Tensor.reshape op_7 1 1 20 25 shape=(%size,%num_heads,%feat_per_head) +Tensor.permute op_8 1 1 21 22 dims=(1,0,2) +Tensor.permute op_9 1 1 23 24 dims=(1,0,2) +Tensor.permute op_10 1 1 25 26 dims=(1,0,2) +Tensor.reshape op_11 1 1 22 27 shape=(%batch,%num_heads,%size,%feat_per_head) +Tensor.reshape op_12 1 1 24 28 shape=(%batch,%num_heads,%size,%feat_per_head) +Tensor.reshape op_13 1 1 26 29 shape=(%batch,%num_heads,%size,%feat_per_head) +Tensor.permute op_14 1 1 28 30 dims=(0,1,3,2) +pnnx.Expression op_15 1 1 27 31 expr=mul(@0,%sqrt_inv_sqrt_embed_dim_per_head) +pnnx.Expression op_16 1 1 30 32 expr=mul(@0,%sqrt_inv_sqrt_embed_dim_per_head) +torch.matmul op_17 2 1 31 32 33 +F.softmax softmax 1 1 33 34 dim=%softmax_dim +torch.matmul op_19 2 1 34 29 35 +Tensor.permute op_20 1 1 35 36 dims=(2,0,1,3) +Tensor.reshape op_21 1 1 36 37 shape=(%size,%embed_dim) +nn.Linear out_proj 1 1 37 38 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_23 1 1 38 out shape=(%size,%batch,%embed_dim) +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.MultiheadAttention attention 1 1 input out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=False add_zero_attn=False add_bias_kv=False +pnnx.Output output 1 0 out +)PNNXIR"; + } + + bool match(const std::map& matched_operators, const std::map& captured_params, const std::map& /*captured_attrs*/) const + { + const int embed_dim = captured_params.at("embed_dim").i; + const int qkv_out_features = captured_params.at("qkv_out_features").i; + const int num_heads = captured_params.at("num_heads").i; + const int feat_per_head = captured_params.at("feat_per_head").i; + const float sqrt_inv_sqrt_embed_dim_per_head = captured_params.at("sqrt_inv_sqrt_embed_dim_per_head").f; + const int softmax_dim = captured_params.at("softmax_dim").i; + + if (qkv_out_features != embed_dim * 3) + return false; + + if (embed_dim != num_heads * feat_per_head) + return false; + + if (!NearlyEqual(sqrt_inv_sqrt_embed_dim_per_head, sqrt(1.f / sqrt(feat_per_head)), 0.001)) + return false; + + int softmax_input_rank = (int)matched_operators.at("softmax")->inputs[0]->shape.size(); + if (softmax_dim != -1 && softmax_dim != softmax_input_rank - 1) + return false; + + return true; + } +}; + +class fuse_multiheadattention_pass_onnx_1_1 : public fuse_multiheadattention_pass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +21 20 +pnnx.Input input_q 0 1 input +nn.Linear op_0 1 1 input 33 
bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight +Tensor.reshape op_1 1 1 33 34 shape=(%batch,%size,1,3,%embed_dim) +Tensor.permute op_2 1 1 34 35 dims=(3,1,2,0,4) +torch.squeeze op_3 1 1 35 36 dim=3 +torch.unbind op_4 1 3 36 37 38 39 dim=0 +Tensor.reshape op_5 1 1 37 40 shape=(%size,%num_heads,%feat_per_head) +Tensor.reshape op_6 1 1 38 42 shape=(%size,%num_heads,%feat_per_head) +Tensor.reshape op_7 1 1 39 43 shape=(%size,%num_heads,%feat_per_head) +Tensor.permute op_8 1 1 40 41 dims=(1,0,2) +Tensor.permute op_9 1 1 42 46 dims=(1,2,0) +Tensor.permute op_10 1 1 43 44 dims=(1,0,2) +pnnx.Expression op_11 1 1 41 45 expr=mul(@0,%inv_sqrt_embed_dim_per_head) +torch.matmul op_12 2 1 45 46 47 +F.softmax softmax 1 1 47 48 dim=%softmax_dim +torch.matmul op_14 2 1 48 44 49 +Tensor.permute op_15 1 1 49 50 dims=(1,0,2) +Tensor.reshape op_16 1 1 50 51 shape=(%size,%embed_dim) +nn.Linear out_proj 1 1 51 52 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_18 1 1 52 out shape=(%size,%batch,%embed_dim) +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.MultiheadAttention attention 1 1 input out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=False add_zero_attn=False add_bias_kv=False +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +class fuse_multiheadattention_pass_onnx_2 : public fuse_multiheadattention_pass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +24 23 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 attn_mask +nn.Linear op_0 1 1 input 15 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight +Tensor.reshape op_1 1 1 15 16 shape=(%batch,%size,1,3,%embed_dim) +Tensor.permute op_2 1 1 16 17 dims=(3,1,2,0,4) +torch.squeeze op_3 1 1 17 18 dim=3 +torch.unbind op_4 1 3 18 19 20 21 dim=0 +Tensor.reshape op_5 1 1 19 23 shape=(%size,%num_heads,%feat_per_head) +Tensor.reshape op_7 1 1 20 25 shape=(%size,%num_heads,%feat_per_head) +Tensor.reshape op_8 1 1 21 26 shape=(%size,%num_heads,%feat_per_head) +Tensor.permute op_6 1 1 23 24 dims=(1,0,2) +Tensor.permute op_11 1 1 25 29 dims=(1,2,0) +Tensor.permute op_9 1 1 26 27 dims=(1,0,2) +pnnx.Expression op_10 1 1 24 28 expr=mul(@0,%inv_sqrt_embed_dim_per_head) +torch.matmul op_12 2 1 28 29 30 +torch.unsqueeze op_13 1 1 attn_mask 22 dim=0 +pnnx.Expression op_14 2 1 30 22 31 expr=add(@0,@1) +F.softmax softmax 1 1 31 32 dim=%softmax_dim +torch.matmul op_16 2 1 32 27 33 +Tensor.permute op_17 1 1 33 34 dims=(1,0,2) +Tensor.reshape op_18 1 1 34 35 shape=(%size,%embed_dim) +nn.Linear out_proj 1 1 35 36 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_20 1 1 36 out shape=(%size,%batch,%embed_dim) +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 attn_mask +nn.MultiheadAttention attention 2 1 input attn_mask out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=False add_zero_attn=False add_bias_kv=False $attn_mask=attn_mask +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +class fuse_multiheadattention_pass_onnx_2_1 : public fuse_multiheadattention_pass_onnx_2 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +24 23 +pnnx.Input input_0 0 1 input 
+pnnx.Input input_1 0 1 attn_mask +nn.Linear op_0 1 1 input 15 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight +Tensor.reshape op_1 1 1 15 16 shape=(%batch,%size,1,3,%embed_dim) +Tensor.permute op_2 1 1 16 17 dims=(3,1,2,0,4) +torch.squeeze op_3 1 1 17 18 dim=3 +torch.unbind op_4 1 3 18 19 20 21 dim=0 +Tensor.reshape op_5 1 1 19 23 shape=(%size,%num_heads,%feat_per_head) +Tensor.reshape op_7 1 1 20 25 shape=(%size,%num_heads,%feat_per_head) +Tensor.reshape op_8 1 1 21 26 shape=(%size,%num_heads,%feat_per_head) +Tensor.permute op_6 1 1 23 24 dims=(1,0,2) +Tensor.permute op_11 1 1 25 29 dims=(1,2,0) +Tensor.permute op_9 1 1 26 27 dims=(1,0,2) +pnnx.Expression op_10 1 1 24 28 expr=mul(@0,%inv_sqrt_embed_dim_per_head) +torch.matmul op_12 2 1 28 29 30 +torch.unsqueeze op_13 1 1 attn_mask 22 dim=0 +pnnx.Expression op_14 2 1 30 22 31 expr=add(@0,@1) +F.softmax softmax 1 1 31 32 dim=%softmax_dim +torch.matmul op_16 2 1 32 27 33 +Tensor.permute op_17 1 1 33 34 dims=(1,0,2) +Tensor.reshape op_18 1 1 34 35 shape=(%size,%embed_dim) +nn.Linear out_proj 1 1 35 36 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_20 1 1 36 out shape=(%size,%batch,%embed_dim) +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +class fuse_multiheadattention_pass_onnx_3 : public fuse_multiheadattention_pass_qkv +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +25 24 +pnnx.Input input_q 0 1 query +pnnx.Input input_kv 0 1 kv +nn.Linear op_0 1 1 query 14 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 kv 15 bias=%kvbias in_features=%kvdim out_features=%kv_embed_dim @bias @weight +Tensor.reshape op_2 1 1 15 16 shape=(%batch,%kvsize,1,2,%embed_dim) +Tensor.permute op_3 1 1 16 17 dims=(3,1,2,0,4) +torch.squeeze op_4 1 1 17 18 dim=3 +torch.unbind op_5 1 2 18 19 20 dim=0 +Tensor.reshape op_6 1 1 14 21 shape=(%qsize,%num_heads,%feat_per_head) +Tensor.reshape op_7 1 1 19 23 shape=(%kvsize,%num_heads,%feat_per_head) +Tensor.reshape op_8 1 1 20 24 shape=(%kvsize,%num_heads,%feat_per_head) +Tensor.permute op_9 1 1 21 22 dims=(1,0,2) +Tensor.permute op_10 1 1 24 25 dims=(1,0,2) +Tensor.permute op_11 1 1 23 27 dims=(1,2,0) +pnnx.Expression op_12 1 1 22 26 expr=mul(@0,%inv_sqrt_embed_dim_per_head) +torch.matmul op_13 2 1 26 27 28 +F.softmax softmax 1 1 28 29 dim=%softmax_dim +torch.matmul op_15 2 1 29 25 30 +Tensor.permute op_16 1 1 30 31 dims=(1,0,2) +Tensor.reshape op_17 1 1 31 32 shape=(%qsize,%embed_dim) +nn.Linear out_proj 1 1 32 33 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_19 1 1 33 out shape=(%qsize,1,%embed_dim) +Tensor.reshape op_20 1 1 29 35 shape=(1,%num_heads,%qsize,%kvsize) +torch.mean op_21 1 1 35 outweight dim=(1) keepdim=False +pnnx.Output output 2 0 out outweight +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +4 4 +pnnx.Input input_0 0 1 query +pnnx.Input input_1 0 1 kv +nn.MultiheadAttention attention 2 2 query kv out outweight embed_dim=%embed_dim kdim=%kvdim vdim=%kvdim num_heads=%num_heads batch_first=False add_zero_attn=False add_bias_kv=False +pnnx.Output output 2 0 out outweight +)PNNXIR"; + } + + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const + { + GraphRewriterPass::write(ops, captured_params, captured_attrs); + + Operator* op = ops.at("attention"); + + const int embed_dim = captured_params.at("embed_dim").i; + 
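The bias assembly just below packs the q bias followed by the kv bias into a single in_proj_bias of length 3 * embed_dim, zero-filling whichever projection was created without a bias. The same memcpy/memset logic in torch terms, as a sketch:

```
import torch

embed_dim = 64
qb = torch.rand(embed_dim)   # op_0.bias, present in this example
kvb = None                   # op_1.bias absent: zero-fill its 2 * embed_dim slots

in_proj_bias = torch.cat([
    qb if qb is not None else torch.zeros(embed_dim),
    kvb if kvb is not None else torch.zeros(2 * embed_dim),
])
assert in_proj_bias.shape == (3 * embed_dim,)
```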
const bool qbias = captured_params.at("qbias").b; + const bool kvbias = captured_params.at("kvbias").b; + const bool outbias = captured_params.at("outbias").b; + const bool bias = qbias || kvbias || outbias; + + op->params["bias"] = bias; + + op->attrs["in_proj_weight"] = captured_attrs.at("op_0.weight") + captured_attrs.at("op_1.weight"); + + op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight"); + + if (bias) + { + op->attrs["in_proj_bias"] = Attribute(); + op->attrs["in_proj_bias"].type = op->attrs["in_proj_weight"].type; + op->attrs["in_proj_bias"].shape = {embed_dim * 3}; + // combine qkv bias + std::vector in_proj_bias(embed_dim * 3); + { + float* in_proj_bias_ptr = (float*)in_proj_bias.data(); + if (qbias) + { + auto qb = captured_attrs.at("op_0.bias").get_float32_data(); + memcpy(in_proj_bias_ptr, (const void*)qb.data(), embed_dim * sizeof(float)); + } + else + { + memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float)); + } + in_proj_bias_ptr += embed_dim; + if (kvbias) + { + auto kvb = captured_attrs.at("op_1.bias").get_float32_data(); + memcpy(in_proj_bias_ptr, (const void*)kvb.data(), embed_dim * 2 * sizeof(float)); + } + else + { + memset(in_proj_bias_ptr, 0, embed_dim * 2 * sizeof(float)); + } + } + op->attrs["in_proj_bias"].set_float32_data(in_proj_bias); + + if (outbias) + { + op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias"); + } + else + { + // init bias as zero + op->attrs["out_proj.bias"] = Attribute(); + op->attrs["out_proj.bias"].type = op->attrs["out_proj.weight"].type; + op->attrs["out_proj.bias"].shape = {embed_dim}; + op->attrs["out_proj.bias"].set_float32_data(std::vector(embed_dim, 0.f)); + } + } + } +}; + +class fuse_multiheadattention_pass_onnx_4 : public fuse_multiheadattention_pass_qkv +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +26 25 +pnnx.Input input_q 0 1 query +pnnx.Input input_k 0 1 key +pnnx.Input input_v 0 1 value +pnnx.Input input_3 0 1 attn_mask +nn.Linear op_0 1 1 query 20 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 key 21 bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 value 22 bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 20 24 shape=(%qsize,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 21 26 shape=(%kvsize,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 22 27 shape=(%kvsize,%num_heads,%feat_per_head) +Tensor.permute op_6 1 1 24 25 dims=(1,0,2) +Tensor.permute op_7 1 1 26 30 dims=(1,2,0) +Tensor.permute op_8 1 1 27 28 dims=(1,0,2) +pnnx.Expression op_9 1 1 25 29 expr=mul(@0,%inv_sqrt_embed_dim_per_head) +torch.matmul op_10 2 1 29 30 31 +torch.unsqueeze op_11 1 1 attn_mask 23 dim=0 +pnnx.Expression op_12 2 1 31 23 32 expr=add(@0,@1) +F.softmax softmax 1 1 32 33 dim=%softmax_dim +torch.matmul op_14 2 1 33 28 34 +Tensor.permute op_15 1 1 34 35 dims=(1,0,2) +Tensor.reshape op_16 1 1 35 36 shape=(%qsize,%embed_dim) +nn.Linear out_proj 1 1 36 37 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_18 1 1 37 out shape=(%qsize,%batch,%embed_dim) +Tensor.reshape op_19 1 1 33 39 shape=(%batch,%num_heads,%qsize,%kvsize) +torch.mean op_20 1 1 39 outweight dim=(1) keepdim=False +pnnx.Output output 2 0 out outweight +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +6 6 +pnnx.Input input_0 0 1 query +pnnx.Input input_1 0 1 key +pnnx.Input input_2 0 1 value +pnnx.Input 
input_3 0 1 attn_mask +nn.MultiheadAttention attention 4 2 query key value attn_mask out outweight embed_dim=%embed_dim kdim=%kdim vdim=%vdim num_heads=%num_heads batch_first=False add_zero_attn=False add_bias_kv=False $attn_mask=attn_mask +pnnx.Output output 2 0 out outweight +)PNNXIR"; + } +}; + +class fuse_multiheadattention_pass_onnx_4_1 : public fuse_multiheadattention_pass_onnx_4 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +25 24 +pnnx.Input input_q 0 1 query +pnnx.Input input_k 0 1 key +pnnx.Input input_v 0 1 value +pnnx.Input input_3 0 1 attn_mask +nn.Linear op_0 1 1 query 22 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 key 23 bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 value 24 bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 22 25 shape=(%qsize,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 23 27 shape=(%kvsize,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 24 28 shape=(%kvsize,%num_heads,%feat_per_head) +Tensor.permute op_6 1 1 25 26 dims=(1,0,2) +Tensor.permute op_7 1 1 28 29 dims=(1,0,2) +Tensor.permute op_8 1 1 27 31 dims=(1,2,0) +pnnx.Expression op_9 1 1 26 30 expr=mul(@0,%inv_sqrt_embed_dim_per_head) +torch.matmul op_10 2 1 30 31 32 +pnnx.Expression op_11 2 1 32 attn_mask 33 expr=add(@0,@1) +F.softmax softmax 1 1 33 34 dim=%softmax_dim +torch.matmul op_13 2 1 34 29 35 +Tensor.permute op_14 1 1 35 36 dims=(1,0,2) +Tensor.reshape op_15 1 1 36 37 shape=(%qsize,%embed_dim) +nn.Linear out_proj 1 1 37 38 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_16 1 1 38 out shape=(%qsize,%batch,%embed_dim) +Tensor.reshape op_18 1 1 34 40 shape=(%batch,%num_heads,%qsize,%kvsize) +torch.mean op_19 1 1 40 outweight dim=(1) keepdim=False +pnnx.Output output 2 0 out outweight +)PNNXIR"; + } +}; + void fuse_multiheadattention(Graph& graph) { #if TORCH_VERSION_MAJOR >= 2 || (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 9) @@ -1606,6 +2044,14 @@ void fuse_multiheadattention(Graph& graph) fuse_multiheadattention_pass_17_1 p1; fuse_multiheadattention_pass_18 q; fuse_multiheadattention_pass_18_1 q1; + + fuse_multiheadattention_pass_onnx onnx0; + fuse_multiheadattention_pass_onnx_1 onnx1; + fuse_multiheadattention_pass_onnx_1_1 onnx1a; + fuse_multiheadattention_pass_onnx_2 onnx2; + fuse_multiheadattention_pass_onnx_3 onnx3; + fuse_multiheadattention_pass_onnx_4 onnx4; + fuse_multiheadattention_pass_onnx_4_1 onnx4a; int opindex = 0; pnnx_graph_rewrite(graph, &a, opindex); @@ -1637,6 +2083,14 @@ void fuse_multiheadattention(Graph& graph) pnnx_graph_rewrite(graph, &p1, opindex); pnnx_graph_rewrite(graph, &q, opindex); pnnx_graph_rewrite(graph, &q1, opindex); + + pnnx_graph_rewrite(graph, &onnx0, opindex); + pnnx_graph_rewrite(graph, &onnx1, opindex); + pnnx_graph_rewrite(graph, &onnx1a, opindex); + pnnx_graph_rewrite(graph, &onnx2, opindex); + pnnx_graph_rewrite(graph, &onnx3, opindex); + pnnx_graph_rewrite(graph, &onnx4, opindex); + pnnx_graph_rewrite(graph, &onnx4a, opindex); #endif } diff --git a/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp b/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp index 18268d0f0fb..a3021d33c90 100644 --- a/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp +++ b/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp @@ -36,7 +36,11 @@ static constant_as_attribute caas[] = { {"If", 0, "cond"}, {"Pad", 1, "pads"}, 
{"Pad", 2, "value"}, + {"ReduceMax", 1, "axes"}, {"ReduceMean", 1, "axes"}, + {"ReduceMin", 1, "axes"}, + {"ReduceProd", 1, "axes"}, + {"ReduceSum", 1, "axes"}, {"Reshape", 1, "shape"}, {"Resize", 2, "scales"}, {"Resize", 3, "sizes"}, diff --git a/tools/pnnx/tests/onnx/CMakeLists.txt b/tools/pnnx/tests/onnx/CMakeLists.txt index 8ed4b5e480a..12d816cd8e2 100644 --- a/tools/pnnx/tests/onnx/CMakeLists.txt +++ b/tools/pnnx/tests/onnx/CMakeLists.txt @@ -19,6 +19,10 @@ pnnx_onnx_add_test(F_conv3d) pnnx_onnx_add_test(F_elu) pnnx_onnx_add_test(F_gelu) # pnnx_onnx_add_test(F_group_norm) +pnnx_onnx_add_test(F_hardshrink) +pnnx_onnx_add_test(F_hardsigmoid) +pnnx_onnx_add_test(F_hardswish) +pnnx_onnx_add_test(F_hardtanh) # pnnx_onnx_add_test(F_instance_norm) pnnx_onnx_add_test(F_interpolate) pnnx_onnx_add_test(F_layer_norm) @@ -59,6 +63,10 @@ pnnx_onnx_add_test(nn_ELU) pnnx_onnx_add_test(nn_GELU) pnnx_onnx_add_test(nn_GroupNorm) pnnx_onnx_add_test(nn_GRU) +pnnx_onnx_add_test(nn_Hardshrink) +pnnx_onnx_add_test(nn_Hardsigmoid) +pnnx_onnx_add_test(nn_Hardswish) +pnnx_onnx_add_test(nn_Hardtanh) pnnx_onnx_add_test(nn_InstanceNorm1d) pnnx_onnx_add_test(nn_InstanceNorm2d) pnnx_onnx_add_test(nn_InstanceNorm3d) @@ -70,7 +78,7 @@ pnnx_onnx_add_test(nn_LSTM) pnnx_onnx_add_test(nn_MaxPool1d) pnnx_onnx_add_test(nn_MaxPool2d) pnnx_onnx_add_test(nn_MaxPool3d) -# pnnx_onnx_add_test(nn_MultiheadAttention) +pnnx_onnx_add_test(nn_MultiheadAttention) pnnx_onnx_add_test(nn_PReLU) pnnx_onnx_add_test(nn_ReflectionPad1d) pnnx_onnx_add_test(nn_ReflectionPad2d) diff --git a/tools/pnnx/tests/onnx/test_F_hardshrink.py b/tools/pnnx/tests/onnx/test_F_hardshrink.py index 52d44ff6c15..00dbdd6fd86 100644 --- a/tools/pnnx/tests/onnx/test_F_hardshrink.py +++ b/tools/pnnx/tests/onnx/test_F_hardshrink.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from packaging import version class Model(nn.Module): def __init__(self): @@ -32,6 +33,9 @@ def forward(self, x, y, z, w): return x, y, z, w def test(): + if version.parse(torch.__version__) < version.parse('1.11'): + return True + net = Model() net.eval() diff --git a/tools/pnnx/tests/onnx/test_F_hardsigmoid.py b/tools/pnnx/tests/onnx/test_F_hardsigmoid.py index 53c3adaefa5..a0f8c7654d3 100644 --- a/tools/pnnx/tests/onnx/test_F_hardsigmoid.py +++ b/tools/pnnx/tests/onnx/test_F_hardsigmoid.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from packaging import version def hardsigmoid_forward_0(x): return F.relu6(x + 3., True) / 6. 
@@ -49,6 +50,9 @@ def forward(self, x, y, z, w): return x, y, z, w def test(): + if version.parse(torch.__version__) < version.parse('1.10'): + return True + net = Model() net.eval() diff --git a/tools/pnnx/tests/onnx/test_F_hardswish.py b/tools/pnnx/tests/onnx/test_F_hardswish.py index cf671d4c611..78ada9b3482 100644 --- a/tools/pnnx/tests/onnx/test_F_hardswish.py +++ b/tools/pnnx/tests/onnx/test_F_hardswish.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from packaging import version def hardswish_forward_0(x): return x * F.hardsigmoid(x) @@ -46,6 +47,9 @@ def forward(self, x, y, z, w): return x, y, z, w def test(): + if version.parse(torch.__version__) < version.parse('1.10'): + return True + net = Model() net.eval() diff --git a/tools/pnnx/tests/onnx/test_nn_Hardshrink.py b/tools/pnnx/tests/onnx/test_nn_Hardshrink.py index c24e48f5bcb..38dd67c464f 100644 --- a/tools/pnnx/tests/onnx/test_nn_Hardshrink.py +++ b/tools/pnnx/tests/onnx/test_nn_Hardshrink.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from packaging import version class Model(nn.Module): def __init__(self): @@ -35,6 +36,9 @@ def forward(self, x, y, z, w): return x, y, z, w def test(): + if version.parse(torch.__version__) < version.parse('1.11'): + return True + net = Model() net.eval() diff --git a/tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py b/tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py index 35d02bef202..43af7471411 100644 --- a/tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py +++ b/tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from packaging import version class Model(nn.Module): def __init__(self): @@ -34,6 +35,9 @@ def forward(self, x, y, z, w): return x, y, z, w def test(): + if version.parse(torch.__version__) < version.parse('1.9'): + return True + net = Model() net.eval() diff --git a/tools/pnnx/tests/onnx/test_nn_MultiheadAttention.py b/tools/pnnx/tests/onnx/test_nn_MultiheadAttention.py new file mode 100644 index 00000000000..9d6cc263ab2 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_MultiheadAttention.py @@ -0,0 +1,138 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
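+
+# Exercises self-attention, cross-attention and kdim/vdim variants, with and
+# without attn_mask, in both the default seq-first and batch_first layouts,
+# so pnnx can fuse the exported ONNX subgraphs back into nn.MultiheadAttention.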
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from packaging import version
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+        self.attention_0_0 = nn.MultiheadAttention(embed_dim=64, num_heads=4)
+        self.attention_0_1 = nn.MultiheadAttention(embed_dim=64, num_heads=8, bias=False, add_bias_kv=False, add_zero_attn=False)
+        self.attention_0_2 = nn.MultiheadAttention(embed_dim=64, num_heads=16, bias=True, add_bias_kv=False, add_zero_attn=False)
+
+        self.attention_0_3 = nn.MultiheadAttention(embed_dim=32, num_heads=8, bias=True)
+        self.attention_0_33 = nn.MultiheadAttention(embed_dim=32, num_heads=8, bias=True)
+
+        self.attention_0_4 = nn.MultiheadAttention(embed_dim=40, num_heads=4, kdim=30, vdim=20)
+        self.attention_0_5 = nn.MultiheadAttention(embed_dim=40, num_heads=8, kdim=30, vdim=20, bias=False, add_bias_kv=False, add_zero_attn=False)
+        self.attention_0_6 = nn.MultiheadAttention(embed_dim=40, num_heads=10, kdim=30, vdim=20, bias=True, add_bias_kv=False, add_zero_attn=False)
+
+        if version.parse(torch.__version__) >= version.parse('1.9'):
+            self.attention_1_0 = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
+            self.attention_1_1 = nn.MultiheadAttention(embed_dim=64, num_heads=8, bias=False, add_bias_kv=False, add_zero_attn=False, batch_first=True)
+            self.attention_1_2 = nn.MultiheadAttention(embed_dim=64, num_heads=16, bias=True, add_bias_kv=False, add_zero_attn=False, batch_first=True)
+
+            self.attention_1_3 = nn.MultiheadAttention(embed_dim=32, num_heads=8, bias=True, batch_first=True)
+            self.attention_1_33 = nn.MultiheadAttention(embed_dim=32, num_heads=8, bias=True, batch_first=True)
+
+            self.attention_1_4 = nn.MultiheadAttention(embed_dim=40, num_heads=4, kdim=30, vdim=20, batch_first=True)
+            self.attention_1_5 = nn.MultiheadAttention(embed_dim=40, num_heads=8, kdim=30, vdim=20, bias=False, add_bias_kv=False, add_zero_attn=False, batch_first=True)
+            self.attention_1_6 = nn.MultiheadAttention(embed_dim=40, num_heads=10, kdim=30, vdim=20, bias=True, add_bias_kv=False, add_zero_attn=False, batch_first=True)
+
+    def forward(self, xq, xk, xv, z, zmask, yq, yk, yv, ymask, ymask2):
+        x0, x0w = self.attention_0_0(xq, xk, xv)
+        x1, x1w = self.attention_0_1(xq, xk, xv)
+        x2, x2w = self.attention_0_2(xq, xk, xk)
+
+        x3, _ = self.attention_0_3(z, z, z, need_weights=False)
+        x33, _ = self.attention_0_33(z, z, z, attn_mask=zmask)
+
+        x4, x4w = self.attention_0_4(yq, yk, yv)
+        x5, x5w = self.attention_0_5(yq, yk, yv, attn_mask=ymask)
+        x6, x6w = self.attention_0_6(yq, yk, yv, attn_mask=ymask2)
+
+        if version.parse(torch.__version__) < version.parse('1.9'):
+            return x0, x0w, x1, x1w, x2, x2w, x3, x33, x4, x4w, x5, x5w, x6, x6w
+
+        xq = xq.transpose(0, 1)
+        xk = xk.transpose(0, 1)
+        xv = xv.transpose(0, 1)
+        z = z.transpose(0, 1)
+        yq = yq.transpose(0, 1)
+        yk = yk.transpose(0, 1)
+        yv = yv.transpose(0, 1)
+
+        y0, y0w = self.attention_1_0(xq, xk, xv)
+        y1, y1w = self.attention_1_1(xq, xk, xv)
+        y2, y2w = self.attention_1_2(xq, xk, xk)
+
+        y3, _ = self.attention_1_3(z, z, z)
+        if version.parse(torch.__version__) >= version.parse('1.12') and version.parse(torch.__version__) < version.parse('1.13'):
+            # HACK pytorch 1.12 breaks 2-dim zmask
+            # https://github.com/pytorch/pytorch/issues/97409
+            # zmask2 = zmask.reshape(1, 1, 30, 30).expand(1, 8, 30, 30)
+            # y33, _ = self.attention_1_33(z, z, z, attn_mask=zmask2)
+            # but it produces all nan then, so skip the test :(
+            y33 = y3.relu()
+        elif version.parse(torch.__version__) >= version.parse('2.0') and version.parse(torch.__version__) < version.parse('2.1'):
+            # HACK pytorch 2.0 produces all nan, skip the test :(
+            y33 = y3.relu()
+        else:
+            y33, _ = self.attention_1_33(z, z, z, attn_mask=zmask.relu())
+
+        y4, y4w = self.attention_1_4(yq, yk, yv)
+        y5, y5w = self.attention_1_5(yq, yk, yv, attn_mask=ymask.relu())
+        y6, y6w = self.attention_1_6(yq, yk, yv, attn_mask=ymask2.relu())
+
+        return x0, x0w, x1, x1w, x2, x2w, x3, x33, x4, x4w, x5, x5w, x6, x6w, y0, y0w, y1, y1w, y2, y2w, y3, y33, y4, y4w, y5, y5w, y6, y6w
+
+def test():
+    if version.parse(torch.__version__) < version.parse('1.10'):
+        return True
+
+    if version.parse(torch.__version__) >= version.parse('2.0') and version.parse(torch.__version__) < version.parse('2.1'):
+        return True
+
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    xq = torch.rand(20, 1, 64)
+    xk = torch.rand(20, 1, 64)
+    xv = torch.rand(20, 1, 64)
+    z = torch.rand(30, 1, 32)
+    zmask = torch.rand(30, 30)
+    yq = torch.rand(15, 1, 40)
+    yk = torch.rand(24, 1, 30)
+    yv = torch.rand(24, 1, 20)
+    ymask = torch.rand(15, 24)
+    ymask2 = torch.rand(10, 15, 24)
+
+    a = net(xq, xk, xv, z, zmask, yq, yk, yv, ymask, ymask2)
+
+    # export onnx
+    torch.onnx.export(net, (xq, xk, xv, z, zmask, yq, yk, yv, ymask, ymask2), "test_nn_MultiheadAttention.onnx")
+
+    # onnx to pnnx
+    import os
+    os.system("../../src/pnnx test_nn_MultiheadAttention.onnx inputshape=[20,1,64],[20,1,64],[20,1,64],[30,1,32],[30,30],[15,1,40],[24,1,30],[24,1,20],[15,24],[10,15,24]")
+
+    # pnnx inference
+    import test_nn_MultiheadAttention_pnnx
+    b = test_nn_MultiheadAttention_pnnx.test_inference()
+
+    for a0, b0 in zip(a, b):
+        if not torch.allclose(a0, b0, 1e-4, 1e-4):
+            return False
+    return True
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)

From 3752d71200e67cb710918c1c535c9af7af1b95f1 Mon Sep 17 00:00:00 2001
From: nihui
Date: Fri, 12 Jul 2024 10:25:37 +0800
Subject: [PATCH 7/8] fix potential fp16s bf16s conflicts on arm vfpv4 (#5578)

* fix potential fp16s bf16s conflicts on armv7 vfpv4

* but prefer fp16 on armv8.2
---
 src/net.cpp        | 36 +++++++++++++++++++++++++++++++++---
 tests/testutil.cpp | 18 ++++++++++++++++--
 2 files changed, 49 insertions(+), 5 deletions(-)

diff --git a/src/net.cpp b/src/net.cpp
index 996337ba36a..3574944e726 100644
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -621,8 +621,17 @@ int NetPrivate::convert_layout(Mat& bottom_blob, const Layer* layer, const Optio
     // clang-format off
     // *INDENT-OFF*
+#if NCNN_ARM82
+    if (opt.use_fp16_storage && cpu_support_arm_asimdhp() && layer->support_fp16_storage)
+    {
+        Mat bottom_blob_fp16;
+        cast_float32_to_float16(bottom_blob, bottom_blob_fp16, opt);
+        bottom_blob = bottom_blob_fp16;
+    }
+    else
+#endif // NCNN_ARM82
 #if NCNN_VFPV4
-    if (opt.use_fp16_storage && cpu_support_arm_vfpv4() && layer->support_fp16_storage)
+    if (opt.use_fp16_storage && !opt.use_bf16_storage && cpu_support_arm_vfpv4() && layer->support_fp16_storage)
     {
         Mat bottom_blob_fp16;
         cast_float32_to_float16(bottom_blob, bottom_blob_fp16, opt);
@@ -740,8 +749,17 @@ int NetPrivate::convert_layout(Mat& bottom_blob, const Layer* layer, const Optio
     // clang-format off
     // *INDENT-OFF*
+#if NCNN_ARM82
+    if (opt.use_fp16_storage && cpu_support_arm_asimdhp() && !layer->support_fp16_storage)
+    {
+        Mat bottom_blob_fp32;
+        cast_float16_to_float32(bottom_blob, bottom_blob_fp32, opt);
+        bottom_blob = bottom_blob_fp32;
+    }
+    else
+#endif // NCNN_ARM82
 #if NCNN_VFPV4
-    if (opt.use_fp16_storage && cpu_support_arm_vfpv4() && !layer->support_fp16_storage)
+    if (opt.use_fp16_storage && !opt.use_bf16_storage && cpu_support_arm_vfpv4() && !layer->support_fp16_storage)
     {
         Mat bottom_blob_fp32;
         cast_float16_to_float32(bottom_blob, bottom_blob_fp32, opt);
@@ -2719,8 +2737,20 @@ int Extractor::extract(int blob_index, Mat& feat, int type)
     // clang-format off
     // *INDENT-OFF*
+#if NCNN_ARM82
+    if (d->opt.use_fp16_storage && cpu_support_arm_asimdhp() && (type == 0))
+    {
+        if (feat.elembits() == 16)
+        {
+            Mat feat_fp32;
+            cast_float16_to_float32(feat, feat_fp32, d->opt);
+            feat = feat_fp32;
+        }
+    }
+    else
+#endif // NCNN_ARM82
 #if NCNN_VFPV4
-    if (d->opt.use_fp16_storage && cpu_support_arm_vfpv4() && (type == 0))
+    if (d->opt.use_fp16_storage && !d->opt.use_bf16_storage && cpu_support_arm_vfpv4() && (type == 0))
     {
         if (feat.elembits() == 16)
         {
diff --git a/tests/testutil.cpp b/tests/testutil.cpp
index 07d95547d44..837043cb754 100644
--- a/tests/testutil.cpp
+++ b/tests/testutil.cpp
@@ -328,8 +328,15 @@ static int convert_to_optimal_layout(const ncnn::Mat& a, ncnn::Mat& a4, const nc
 {
     // clang-format off
     // *INDENT-OFF*
+#if NCNN_ARM82
+    if (opt.use_fp16_storage && ncnn::cpu_support_arm_asimdhp() && op->support_fp16_storage && !(flag & TEST_LAYER_DISABLE_AUTO_INPUT_CASTING))
+    {
+        ncnn::cast_float32_to_float16(a, a4, opt);
+    }
+    else
+#endif // NCNN_ARM82
 #if NCNN_VFPV4
-    if (opt.use_fp16_storage && ncnn::cpu_support_arm_vfpv4() && op->support_fp16_storage && !(flag & TEST_LAYER_DISABLE_AUTO_INPUT_CASTING))
+    if (opt.use_fp16_storage && !opt.use_bf16_storage && ncnn::cpu_support_arm_vfpv4() && op->support_fp16_storage && !(flag & TEST_LAYER_DISABLE_AUTO_INPUT_CASTING))
     {
         ncnn::cast_float32_to_float16(a, a4, opt);
     }
@@ -449,8 +456,15 @@ static int convert_to_vanilla_layout(const ncnn::Mat& c4, ncnn::Mat& c, const nc
     // clang-format off
     // *INDENT-OFF*
+#if NCNN_ARM82
+    if (opt.use_fp16_storage && ncnn::cpu_support_arm_asimdhp() && op->support_fp16_storage && c4_unpacked.elembits() == 16)
+    {
+        ncnn::cast_float16_to_float32(c4_unpacked, c, opt);
+    }
+    else
+#endif // NCNN_ARM82
 #if NCNN_VFPV4
-    if (opt.use_fp16_storage && ncnn::cpu_support_arm_vfpv4() && op->support_fp16_storage && c4_unpacked.elembits() == 16)
+    if (opt.use_fp16_storage && !opt.use_bf16_storage && ncnn::cpu_support_arm_vfpv4() && op->support_fp16_storage && c4_unpacked.elembits() == 16)
     {
         ncnn::cast_float16_to_float32(c4_unpacked, c, opt);
     }

From 1c40615b2dbeac41cffa9738d4888b822e1509ba Mon Sep 17 00:00:00 2001
From: nihui
Date: Fri, 12 Jul 2024 14:36:10 +0800
Subject: [PATCH 8/8] pnnx convert onnx sdpa reduce min/max/mean/sum/prod (#5579)

* pnnx convert onnx sdpa

* test reduce
---
 .../F_scaled_dot_product_attention.cpp        | 91 +++++++++++++++++++
 tools/pnnx/src/pass_level2/torch_max.cpp      | 80 +++++++++++++---
 tools/pnnx/src/pass_level2/torch_min.cpp      | 80 +++++++++++++---
 tools/pnnx/src/pass_level2/torch_prod.cpp     | 32 ++++---
 tools/pnnx/tests/onnx/CMakeLists.txt          |  8 +-
 .../test_F_scaled_dot_product_attention.py    | 64 +++++++++++++
 tools/pnnx/tests/onnx/test_torch_max.py       | 62 +++++++++++++
 tools/pnnx/tests/onnx/test_torch_mean.py      | 60 ++++++++++++
 tools/pnnx/tests/onnx/test_torch_min.py       | 62 +++++++++++++
 tools/pnnx/tests/onnx/test_torch_prod.py      | 60 ++++++++++++
 tools/pnnx/tests/onnx/test_torch_sum.py       | 60 ++++++++++++
 11 files changed, 617 insertions(+), 42 deletions(-)
 create mode 100644 tools/pnnx/tests/onnx/test_F_scaled_dot_product_attention.py
 create mode 100644 tools/pnnx/tests/onnx/test_torch_max.py
 create mode 100644 tools/pnnx/tests/onnx/test_torch_mean.py
 create mode 100644 tools/pnnx/tests/onnx/test_torch_min.py
 create mode 100644 tools/pnnx/tests/onnx/test_torch_prod.py
 create mode 100644 tools/pnnx/tests/onnx/test_torch_sum.py

diff --git a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp
index 36ca3c334f2..9fba1e770cc 100644
--- a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp
+++ b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp
@@ -80,4 +80,95 @@ pnnx.Output output 1 0 out
 
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10)
 
+static bool NearlyEqual(float a, float b, float epsilon)
+{
+    if (a == b)
+        return true;
+
+    float diff = (float)fabs(a - b);
+    if (diff <= epsilon)
+        return true;
+
+    // relative error
+    return diff < epsilon * std::max(fabs(a), fabs(b));
+}
+
+class F_scaled_dot_product_attention_onnx : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+12 11
+pnnx.Input input_0 0 1 query
+pnnx.Input input_1 0 1 key
+pnnx.Input input_2 0 1 value
+Transpose op_0 1 1 key kt perm=(0,1,3,2)
+prim::Constant op_1 0 1 scale value=%sqrt_scale
+aten::mul op_2 2 1 query scale q
+prim::Constant op_3 0 1 scale2 value=%sqrt_scale
+aten::mul op_4 2 1 kt scale2 k
+MatMul op_5 2 1 q k qk
+Softmax op_6 1 1 qk 4 axis=-1
+MatMul op_7 2 1 4 value out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "F.scaled_dot_product_attention";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        op->params["dropout_p"] = 0.f;
+        op->params["is_causal"] = false;
+
+        const float sqrt_scale = captured_params.at("sqrt_scale").f;
+        const float scale = sqrt_scale * sqrt_scale;
+
+        op->params["scale"] = scale;
+
+        if (!op->inputs[0]->shape.empty())
+        {
+            const int embed_dim = op->inputs[0]->shape[op->inputs[0]->shape.size() - 1];
+            if (NearlyEqual(scale, 1.f / sqrt(embed_dim), 0.001))
+            {
+                // drop scale=None for compatibility with old torch
+                op->params.erase("scale");
+            }
+        }
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_onnx, 10)
+
+class F_scaled_dot_product_attention_onnx_1 : public F_scaled_dot_product_attention_onnx
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+14 13
+pnnx.Input input_0 0 1 query
+pnnx.Input input_1 0 1 key
+pnnx.Input input_2 0 1 value
+pnnx.Input input_3 0 1 attn_mask
+Transpose op_0 1 1 key kt perm=(0,1,3,2)
+prim::Constant op_1 0 1 scale value=%sqrt_scale
+aten::mul op_2 2 1 query scale q
+prim::Constant op_3 0 1 scale2 value=%sqrt_scale
+aten::mul op_4 2 1 kt scale2 k
+MatMul op_5 2 1 q k qk
+aten::add op_6 2 1 qk attn_mask qkm
+Softmax op_7 1 1 qkm 4 axis=-1
+MatMul op_8 2 1 4 value out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_onnx_1, 10)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/torch_max.cpp b/tools/pnnx/src/pass_level2/torch_max.cpp
index 68479b85d5b..b606fed066b 100644
--- a/tools/pnnx/src/pass_level2/torch_max.cpp
+++ b/tools/pnnx/src/pass_level2/torch_max.cpp
@@ -83,30 +83,80 @@ pnnx.Output output 1 0 out
         if (captured_params.find("op_0.axes") != captured_params.end())
         {
             op->params["dim"] = captured_params.at("op_0.axes");
-        }
-        else
-        {
-            // reduce all
-            const int input_rank = (int)op->inputs[0]->shape.size();
-            std::vector<int> dim(input_rank);
-            for (int i = 0; i < input_rank; i++)
+
+            if (captured_params.find("op_0.keepdims") != captured_params.end())
             {
-                dim[i] = i;
+                op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
+            }
+            else
+            {
+                op->params["keepdim"] = true;
             }
-            op->params["dim"] = dim;
-        }
-
-        if (captured_params.find("op_0.keepdims") != captured_params.end())
-        {
-            op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
         }
         else
         {
-            op->params["keepdim"] = true;
+            // reduce all
         }
     }
 };
 
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_onnx, 20)
 
+class torch_max_onnx_1 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input input 0 1 input
+ReduceMax op_0 1 1 input out %*=%*
+ArgMax op_1 1 1 input indices %*=%*
+pnnx.Output output 2 0 out indices
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "torch.max";
+    }
+
+    bool match(const std::map<std::string, Parameter>& captured_params) const
+    {
+        if (captured_params.find("op_0.axes") == captured_params.end())
+            return false;
+
+        if (captured_params.find("op_0.keepdims") == captured_params.end())
+            return false;
+
+        if (captured_params.find("op_1.axis") == captured_params.end())
+            return false;
+
+        if (captured_params.find("op_1.keepdims") == captured_params.end())
+            return false;
+
+        if (captured_params.at("op_0.axes").type != 5 || captured_params.at("op_0.axes").ai.size() != 1)
+            return false;
+
+        if (captured_params.at("op_1.axis").type != 2)
+            return false;
+
+        if (captured_params.at("op_0.axes").ai[0] != captured_params.at("op_1.axis").i)
+            return false;
+
+        if (captured_params.at("op_0.keepdims").i != captured_params.at("op_1.keepdims").i)
+            return false;
+
+        return true;
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        op->params["dim"] = captured_params.at("op_0.axes").ai[0];
+        op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_onnx_1, 19)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/torch_min.cpp b/tools/pnnx/src/pass_level2/torch_min.cpp
index c5e48bbc64b..35cc4988a19 100644
--- a/tools/pnnx/src/pass_level2/torch_min.cpp
+++ b/tools/pnnx/src/pass_level2/torch_min.cpp
@@ -83,30 +83,80 @@ pnnx.Output output 1 0 out
         if (captured_params.find("op_0.axes") != captured_params.end())
         {
             op->params["dim"] = captured_params.at("op_0.axes");
-        }
-        else
-        {
-            // reduce all
-            const int input_rank = (int)op->inputs[0]->shape.size();
-            std::vector<int> dim(input_rank);
-            for (int i = 0; i < input_rank; i++)
+
+            if (captured_params.find("op_0.keepdims") != captured_params.end())
             {
-                dim[i] = i;
+                op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
+            }
+            else
+            {
+                op->params["keepdim"] = true;
             }
-            op->params["dim"] = dim;
-        }
-
-        if (captured_params.find("op_0.keepdims") != captured_params.end())
-        {
-            op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
         }
         else
         {
-            op->params["keepdim"] = true;
+            // reduce all
         }
     }
 };
 
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_onnx, 20)
 
+class torch_min_onnx_1 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input input 0 1 input
+ReduceMin op_0 1 1 input out %*=%*
+ArgMin op_1 1 1 input indices %*=%*
+pnnx.Output output 2 0 out indices
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "torch.min";
+    }
+
+    bool match(const std::map<std::string, Parameter>& captured_params) const
+    {
+        if (captured_params.find("op_0.axes") == captured_params.end())
+            return false;
+
+        if (captured_params.find("op_0.keepdims") == captured_params.end())
+            return false;
+
+        if (captured_params.find("op_1.axis") == captured_params.end())
+            return false;
+
+        if (captured_params.find("op_1.keepdims") == captured_params.end())
+            return false;
+
+        if (captured_params.at("op_0.axes").type != 5 || captured_params.at("op_0.axes").ai.size() != 1)
+            return false;
+
+        if (captured_params.at("op_1.axis").type != 2)
+            return false;
+
+        if (captured_params.at("op_0.axes").ai[0] != captured_params.at("op_1.axis").i)
+            return false;
+
+        if (captured_params.at("op_0.keepdims").i != captured_params.at("op_1.keepdims").i)
+            return false;
+
+        return true;
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        op->params["dim"] = captured_params.at("op_0.axes").ai[0];
+        op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_onnx_1, 19)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/torch_prod.cpp b/tools/pnnx/src/pass_level2/torch_prod.cpp
index 7f15c2ba88a..51b614ec0dc 100644
--- a/tools/pnnx/src/pass_level2/torch_prod.cpp
+++ b/tools/pnnx/src/pass_level2/torch_prod.cpp
@@ -58,24 +58,34 @@ pnnx.Output output 1 0 out
         return "torch.prod";
     }
 
+    bool match(const std::map<std::string, Parameter>& captured_params) const
+    {
+        if (captured_params.find("op_0.axes") == captured_params.end())
+            return false;
+
+        if (captured_params.at("op_0.axes").type != 2 && captured_params.at("op_0.axes").type != 5)
+            return false;
+
+        if (captured_params.at("op_0.axes").type == 5 && captured_params.at("op_0.axes").ai.size() > 1)
+            return false;
+
+        return true;
+    }
+
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.find("op_0.axes") != captured_params.end())
+        int dim;
+        if (captured_params.at("op_0.axes").type == 2)
         {
-            op->params["dim"] = captured_params.at("op_0.axes");
+            dim = captured_params.at("op_0.axes").i;
         }
-        else
+        else // if (captured_params.at("op_0.axes").type == 5)
         {
-            // reduce all
-            const int input_rank = (int)op->inputs[0]->shape.size();
-            std::vector<int> dim(input_rank);
-            for (int i = 0; i < input_rank; i++)
-            {
-                dim[i] = i;
-            }
-            op->params["dim"] = dim;
+            dim = captured_params.at("op_0.axes").ai[0];
         }
+        op->params["dim"] = dim;
 
         if (captured_params.find("op_0.keepdims") != captured_params.end())
         {
             op->params["keepdim"] = captured_params.at("op_0.keepdims").i ?
true : false; diff --git a/tools/pnnx/tests/onnx/CMakeLists.txt b/tools/pnnx/tests/onnx/CMakeLists.txt index 12d816cd8e2..0c0a136fbaf 100644 --- a/tools/pnnx/tests/onnx/CMakeLists.txt +++ b/tools/pnnx/tests/onnx/CMakeLists.txt @@ -36,7 +36,7 @@ pnnx_onnx_add_test(F_pad) pnnx_onnx_add_test(F_prelu) pnnx_onnx_add_test(F_relu) pnnx_onnx_add_test(F_relu6) -# pnnx_onnx_add_test(F_scaled_dot_product_attention) +pnnx_onnx_add_test(F_scaled_dot_product_attention) pnnx_onnx_add_test(F_sigmoid) pnnx_onnx_add_test(F_softmax) pnnx_onnx_add_test(F_upsample_bilinear) @@ -103,3 +103,9 @@ pnnx_onnx_add_test(shufflenet_v2_x1_0) pnnx_onnx_add_test(squeezenet1_1) pnnx_onnx_add_test(swin_t) pnnx_onnx_add_test(vit_b_32) + +pnnx_onnx_add_test(torch_max) +pnnx_onnx_add_test(torch_mean) +pnnx_onnx_add_test(torch_min) +pnnx_onnx_add_test(torch_prod) +pnnx_onnx_add_test(torch_sum) diff --git a/tools/pnnx/tests/onnx/test_F_scaled_dot_product_attention.py b/tools/pnnx/tests/onnx/test_F_scaled_dot_product_attention.py new file mode 100644 index 00000000000..2802b3794a7 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_scaled_dot_product_attention.py @@ -0,0 +1,64 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, q, k, v, m): + x = F.scaled_dot_product_attention(q, k, v) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=m) + return x, y + +def test(): + if version.parse(torch.__version__) < version.parse('2.1'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + q = torch.rand(3, 8, 128, 64) + k = torch.rand(3, 8, 48, 64) + v = torch.rand(3, 8, 48, 77) + m = torch.rand(3, 8, 128, 48) + + a = net(q, k, v, m) + + # export onnx + torch.onnx.export(net, (q, k, v, m), "test_F_scaled_dot_product_attention.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_scaled_dot_product_attention.onnx inputshape=[3,8,128,64],[3,8,48,64],[3,8,48,77],[3,8,128,48]") + + # pnnx inference + import test_F_scaled_dot_product_attention_pnnx + b = test_F_scaled_dot_product_attention_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_torch_max.py b/tools/pnnx/tests/onnx/test_torch_max.py new file mode 100644 index 00000000000..0ab18bec47d --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_max.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. 
+# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x, x_indices = torch.max(x, dim=1, keepdim=False) + y = torch.max(y) + w = torch.max(z, w) + z, z_indices = torch.max(z, dim=0, keepdim=True) + return x, x_indices, y, z, z_indices, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + w = torch.rand(5, 9, 10) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_torch_max.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_torch_max.onnx inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10],[5,9,10]") + + # pnnx inference + import test_torch_max_pnnx + b = test_torch_max_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_torch_mean.py b/tools/pnnx/tests/onnx/test_torch_mean.py new file mode 100644 index 00000000000..cf599057579 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_mean.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
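+
+# torch.mean with int and tuple dims exports to ONNX ReduceMean; this checks
+# that pnnx converts it back to torch.mean with matching dim/keepdim.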
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.mean(x, dim=1, keepdim=False) + y = torch.mean(y, dim=(2,3), keepdim=False) + z = torch.mean(z, dim=0, keepdim=True) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export onnx + torch.onnx.export(net, (x, y, z), "test_torch_mean.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_torch_mean.onnx inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_mean_pnnx + b = test_torch_mean_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_torch_min.py b/tools/pnnx/tests/onnx/test_torch_min.py new file mode 100644 index 00000000000..e41584afe0c --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_min.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x, x_indices = torch.min(x, dim=1, keepdim=False) + y = torch.min(y) + w = torch.min(z, w) + z, z_indices = torch.min(z, dim=0, keepdim=True) + return x, x_indices, y, z, z_indices, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + w = torch.rand(5, 9, 10) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_torch_min.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_torch_min.onnx inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10],[5,9,10]") + + # pnnx inference + import test_torch_min_pnnx + b = test_torch_min_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_torch_prod.py b/tools/pnnx/tests/onnx/test_torch_prod.py new file mode 100644 index 00000000000..c36b97c2a31 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_prod.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. 
You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.prod(x, dim=1, keepdim=False) + y = torch.prod(y, dim=2, keepdim=False) + z = torch.prod(z, dim=0, keepdim=True) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export onnx + torch.onnx.export(net, (x, y, z), "test_torch_prod.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_torch_prod.onnx inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_prod_pnnx + b = test_torch_prod_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_torch_sum.py b/tools/pnnx/tests/onnx/test_torch_sum.py new file mode 100644 index 00000000000..3ae6412f09b --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_sum.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.sum(x, dim=1, keepdim=False) + y = torch.sum(y, dim=(2,3), keepdim=False) + z = torch.sum(z, dim=0, keepdim=True) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export onnx + torch.onnx.export(net, (x, y, z), "test_torch_sum.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_torch_sum.onnx inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_sum_pnnx + b = test_torch_sum_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)