From 7140dd1227cebc23da11f2d7ebd5fbedd70251cb Mon Sep 17 00:00:00 2001
From: Hisham
Date: Sat, 6 Jul 2024 04:50:08 +0800
Subject: [PATCH] Fix PReLU Broadcasting Bug for Multiple Parameters

#################Summary#################

Fixed a bug in the PReLU module in python/jittor/nn.py where broadcasting the
weight parameter caused errors when num_parameters was greater than 1. The
previous implementation hard-coded the broadcast dimensions to [0, 2, 3],
which assumes a 4D input; for inputs with any other number of dimensions the
weights were not broadcast to match the input shape, leading to runtime
errors.

#################Changes Made#################

Modified the execute method of the PReLU class to build the broadcast shape
from the input's actual dimensions, so the weight parameter is broadcast
correctly whenever num_parameters is greater than 1.

#################Code Changes#################

#################Original Code:#################

def __init__(self, num_parameters=1, init_=0.25):
    self.num_parameters = num_parameters
    self.weight = init.constant((num_parameters,), "float32", init_)

def execute(self, x):
    if self.num_parameters != 1:
        assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU"
        return jt.maximum(0, x) + self.weight.broadcast(x, [0,2,3]) * jt.minimum(0, x)
    else:
        return jt.maximum(0, x) + self.weight * jt.minimum(0, x)

#################Updated Code:#################

def __init__(self, num_parameters=1, init_=0.25):
    self.num_parameters = num_parameters
    self.weight = init.constant((num_parameters,), "float32", init_)

def execute(self, x):
    if self.num_parameters != 1:
        assert self.num_parameters == x.shape[1], f"num_parameters does not match input channels in PReLU"
        # Adjust broadcasting logic to ensure it matches the input dimensions
        shape = [x.shape[0], self.num_parameters] + [1] * (len(x.shape) - 2)
        weight_broadcasted = self.weight.broadcast(shape)
        return jt.maximum(0, x) + weight_broadcasted * jt.minimum(0, x)
    else:
        return jt.maximum(0, x) + self.weight * jt.minimum(0, x)
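#################Broadcast Shape Illustration:#################

As a quick illustration (not part of the patch), the broadcast shape above is
just the batch dimension, then the channel dimension, then a 1 for every
remaining axis. The helper below is hypothetical and uses plain Python lists
only, so it can be checked without Jittor installed:

# Hypothetical helper, for illustration only: the shape computed in the
# updated execute method.
def prelu_broadcast_shape(input_shape, num_parameters):
    # Batch dim, channel dim, then a 1 per remaining axis, so the weight
    # aligns with the channel axis (axis 1).
    return [input_shape[0], num_parameters] + [1] * (len(input_shape) - 2)

assert prelu_broadcast_shape([8, 16, 32, 32], 16) == [8, 16, 1, 1]  # NCHW input
assert prelu_broadcast_shape([5, 5], 5) == [5, 5]  # 2D input, as in the tests below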
#################Testing#################

Tested the updated PReLU module with several configurations to verify correct
behavior:

import jittor as jt
from jittor import nn

# Create input data with the specified shape
def create_input_data(shape):
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    return jt.array(list(range(-num_elements // 2, num_elements // 2)), dtype=jt.float32).reshape(shape)

# Test the PReLU activation function
def test_prelu(num_parameters, input_shape):
    prelu_layer = nn.PReLU(num_parameters=num_parameters)
    input_data = create_input_data(input_shape)
    print(f"Testing PReLU with num_parameters={num_parameters} and input_shape={input_shape}")
    print(f"Input Data:\n{input_data.numpy()}")
    output_data = prelu_layer(input_data)
    print(f"Output Data (PReLU):\n{output_data.numpy()}\n")

if __name__ == "__main__":
    test_configs = [
        (1, (5,)),    # Single parameter
        (5, (5, 5)),  # Five parameters matching the number of channels
        (3, (3, 3)),  # Three parameters matching the number of channels
    ]
    for num_parameters, input_shape in test_configs:
        test_prelu(num_parameters, input_shape)

#################Test Results:#################

Testing PReLU with num_parameters=1 and input_shape=(5,)
Input Data:
[-3. -2. -1.  0.  1.]
Output Data (PReLU):
[-0.75 -0.5  -0.25  0.    1.  ]

Testing PReLU with num_parameters=5 and input_shape=(5, 5)
Input Data:
[[-13. -12. -11. -10.  -9.]
 [ -8.  -7.  -6.  -5.  -4.]
 [ -3.  -2.  -1.   0.   1.]
 [  2.   3.   4.   5.   6.]
 [  7.   8.   9.  10.  11.]]
Output Data (PReLU):
[[-3.25 -3.   -2.75 -2.5  -2.25]
 [-2.   -1.75 -1.5  -1.25 -1.  ]
 [-0.75 -0.5  -0.25  0.    1.  ]
 [ 2.    3.    4.    5.    6.  ]
 [ 7.    8.    9.   10.   11.  ]]

Testing PReLU with num_parameters=3 and input_shape=(3, 3)
Input Data:
[[-5. -4. -3.]
 [-2. -1.  0.]
 [ 1.  2.  3.]]
Output Data (PReLU):
[[-1.25 -1.   -0.75]
 [-0.5  -0.25  0.  ]
 [ 1.    2.    3.  ]]

##################################

This fix ensures that the PReLU activation function handles multiple
parameters correctly by broadcasting the weight parameter to match the input
tensor dimensions.
---
 python/jittor/nn.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/python/jittor/nn.py b/python/jittor/nn.py
index 3208a179..22b934c9 100644
--- a/python/jittor/nn.py
+++ b/python/jittor/nn.py
@@ -364,10 +364,14 @@
     def __init__(self, num_parameters=1, init_=0.25):
         self.num_parameters = num_parameters
         self.weight = init.constant((num_parameters,), "float32", init_)
+
     def execute(self, x):
         if self.num_parameters != 1:
-            assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU"
-            return jt.maximum(0, x) + self.weight.broadcast(x, [0,2,3]) * jt.minimum(0, x)
+            assert self.num_parameters == x.shape[1], f"num_parameters does not match input channels in PReLU"
+            # Adjust broadcasting logic to ensure it matches the input dimensions
+            shape = [x.shape[0], self.num_parameters] + [1] * (len(x.shape) - 2)
+            weight_broadcasted = self.weight.broadcast(shape)
+            return jt.maximum(0, x) + weight_broadcasted * jt.minimum(0, x)
         else:
             return jt.maximum(0, x) + self.weight * jt.minimum(0, x)
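#################NumPy Cross-Check#################

As an aside, not part of the patch: the reported outputs can be reproduced
with a small NumPy reference sketch, assuming the default init_=0.25 and a
channel-first layout (the name prelu_ref is made up for this illustration):

import numpy as np

def prelu_ref(x, weight):
    # Reshape the weight to (C, 1, ..., 1) so NumPy's implicit broadcasting
    # aligns it with the channel axis (axis 1 for inputs with ndim >= 2).
    w = weight.reshape([-1] + [1] * max(x.ndim - 2, 0))
    return np.maximum(0, x) + w * np.minimum(0, x)

x = np.arange(-13, 12, dtype=np.float32).reshape(5, 5)
w = np.full(5, 0.25, dtype=np.float32)
print(prelu_ref(x, w))  # reproduces the 5x5 output in the test results above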