[Modify] Abstract the layers in SWNN with class LinearNetwork
Y-nuclear committed Apr 14, 2024
1 parent 07208d8 commit 2a6ad0f
Showing 1 changed file with 73 additions and 62 deletions.
135 changes: 73 additions & 62 deletions src/gnnwr/networks.py
@@ -26,6 +26,68 @@ def default_dense_layer(insize, outsize):
        size = int(math.pow(2, int(math.log2(size)) - 1))
    return dense_layer

class LinearNetwork(nn.Module):
"""
LinearNetwork is a neural network with dense layers, which is used to calculate the weight of features.
| The each layer of LinearNetwork is as follows:
| full connection layer -> batch normalization layer -> activate function -> drop out layer
Parameters
----------
dense_layer: list
a list of dense layers of Neural Network
insize: int
input size of Neural Network(must be positive)
outsize: int
Output size of Neural Network(must be positive)
drop_out: float
drop out rate(default: ``0.2``)
activate_func: torch.nn.functional
activate function(default: ``nn.PReLU(init=0.1)``)
batch_norm: bool
whether use batch normalization(default: ``True``)
"""
    def __init__(self, insize, outsize, drop_out=0, activate_func=None, batch_norm=False):
        super(LinearNetwork, self).__init__()
        self.layer = nn.Linear(insize, outsize)
        if drop_out < 0 or drop_out > 1:
            raise ValueError("drop_out must be in [0, 1]")
        elif drop_out == 0:
            self.drop_out = nn.Identity()
        else:
            self.drop_out = nn.Dropout(drop_out)
        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(outsize)
        else:
            self.batch_norm = nn.Identity()

        if activate_func is None:
            self.activate_func = nn.Identity()
        else:
            self.activate_func = activate_func
        self.reset_parameter()

    def reset_parameter(self):
        torch.nn.init.kaiming_uniform_(self.layer.weight, a=0, mode='fan_in')
        if self.layer.bias is not None:
            self.layer.bias.data.fill_(0)

    def forward(self, x):
        x = x.to(torch.float32)
        x = self.layer(x)
        x = self.batch_norm(x)
        x = self.activate_func(x)
        x = self.drop_out(x)
        return x

    def __str__(self) -> str:
        return f"LinearNetwork: {self.layer.in_features} -> {self.layer.out_features}\n" + \
               f"Dropout: {self.drop_out}\n" + \
               f"BatchNorm: {self.batch_norm}\n" + \
               f"Activation: {self.activate_func}"

    def __repr__(self) -> str:
        return self.__str__()
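For orientation, here is a minimal usage sketch of the new LinearNetwork block. The sizes, dropout rate, and activation below are illustrative assumptions rather than values from the repository, and the import path is assumed from the file location src/gnnwr/networks.py:

import torch
from torch import nn
from gnnwr.networks import LinearNetwork  # assumed import path

# one block: Linear(16 -> 8) -> BatchNorm1d -> PReLU -> Dropout(0.2)
block = LinearNetwork(16, 8, drop_out=0.2, activate_func=nn.PReLU(init=0.1), batch_norm=True)
x = torch.randn(4, 16)   # a batch of 4 samples with 16 features
y = block(x)             # forward() casts the input to float32 before the dense layer
print(y.shape)           # torch.Size([4, 8])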

class SWNN(nn.Module):
"""
@@ -68,26 +130,15 @@ def __init__(self, dense_layer=None, insize=-1, outsize=-1, drop_out=0.2, activa
        self.fc = nn.Sequential()

        for size in self.dense_layer:
-           # add full connection layer
            self.fc.add_module("swnn_full" + str(count),
-                              nn.Linear(lastsize, size, bias=True))  # add full connection layer
-           if batch_norm:
-               # add batch normalization layer if needed
-               self.fc.add_module("swnn_batc" + str(count), nn.BatchNorm1d(size))
-           self.fc.add_module("swnn_acti" + str(count), self.activate_func)  # add activate function
-           self.fc.add_module("swnn_drop" + str(count),
-                              nn.Dropout(self.drop_out))  # add drop_out layer
-           lastsize = size  # update the size of last layer
+                              LinearNetwork(lastsize, size, drop_out, activate_func, batch_norm))
+           lastsize = size
            count += 1
        self.fc.add_module("full" + str(count),
-                          nn.Linear(lastsize, self.outsize))  # add the last full connection layer
-       for m in self.modules():
-           if isinstance(m, nn.Linear):
-               torch.nn.init.kaiming_uniform_(m.weight, a=0, mode='fan_in')
-               if m.bias is not None:
-                   m.bias.data.fill_(0)
+                          LinearNetwork(lastsize, self.outsize))

    def forward(self, x):
-       x.to(torch.float32)
+       x = x.to(torch.float32)
        x = self.fc(x)
        return x
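With the refactor, the body of SWNN is just a stack of LinearNetwork blocks followed by a plain output block. A hand-built sketch of what the loop above produces, assuming dense_layer=[64, 32], insize=10, outsize=1, the default dropout of 0.2, and a PReLU activation with batch normalization enabled (all of these values are assumptions for illustration):

import torch
from torch import nn
from gnnwr.networks import LinearNetwork  # assumed import path

fc = nn.Sequential(
    LinearNetwork(10, 64, drop_out=0.2, activate_func=nn.PReLU(init=0.1), batch_norm=True),
    LinearNetwork(64, 32, drop_out=0.2, activate_func=nn.PReLU(init=0.1), batch_norm=True),
    LinearNetwork(32, 1),  # final block: bare Linear, since dropout/batch norm/activation default off
)
print(fc(torch.randn(8, 10)).shape)  # torch.Size([8, 1])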

@@ -129,27 +180,15 @@ def __init__(self, dense_layer, insize, outsize, drop_out=0.2, activate_func=nn.
        self.fc = nn.Sequential()
        for size in self.dense_layer:
            self.fc.add_module("stpnn_full" + str(count),
-                              nn.Linear(lastsize, size))  # add full connection layer
-           if batch_norm:
-               # add batch normalization layer if needed
-               self.fc.add_module("stpnn_batc" + str(count), nn.BatchNorm1d(size))
-           self.fc.add_module("stpnn_acti" + str(count), self.activate_func)  # add activate function
-           self.fc.add_module("stpnn_drop" + str(count),
-                              nn.Dropout(self.drop_out))  # add drop_out layer
-           lastsize = size  # update the size of last layer
+                              LinearNetwork(lastsize, size, drop_out, activate_func, batch_norm))
+           lastsize = size
            count += 1
-       self.fc.add_module("full" + str(count), nn.Linear(lastsize, self.outsize))  # add the last full connection layer
-       self.fc.add_module("acti" + str(count), nn.ReLU())
-
-       for m in self.modules():
-           if isinstance(m, nn.Linear):
-               torch.nn.init.kaiming_uniform_(m.weight)
-               if m.bias is not None:
-                   m.bias.data.fill_(0)
+       self.fc.add_module("full" + str(count),
+                          LinearNetwork(lastsize, self.outsize, activate_func=activate_func))

    def forward(self, x):
        # STPNN
-       x.to(torch.float32)
+       x = x.to(torch.float32)
        batch = x.shape[0]
        height = x.shape[1]
        x = torch.reshape(x, shape=(batch * height, x.shape[2]))
@@ -196,32 +235,4 @@ def forward(self, input1):
        STNN_output = self.STNN(STNN_input)
        SPNN_output = self.SPNN(SPNN_input)
        output = torch.cat((STNN_output, SPNN_output), dim=-1)
        return output


-# weight-sharing computation
-def weight_share(model, x, output_size=1):
-    """
-    weight_share is a function to calculate the output of a neural network with weight sharing.
-
-    Parameters
-    ----------
-    model: torch.nn.Module
-        neural network with weight sharing
-    x: torch.Tensor
-        input of the neural network
-    output_size: int
-        output size of the neural network
-
-    Returns
-    -------
-    output: torch.Tensor
-        output of the neural network
-    """
-    x.to(torch.float32)
-    batch = x.shape[0]
-    height = x.shape[1]
-    x = torch.reshape(x, shape=(batch * height, x.shape[2]))
-    output = model(x)
-    output = torch.reshape(output, shape=(batch, height, output_size))
-    return output
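The deleted weight_share helper documents the weight-sharing trick that STPNN.forward performs: merge the batch and height dimensions, push every point through the same network, then restore the original shape. A minimal sketch of that pattern with a stand-in model and assumed sizes:

import torch
from torch import nn

model = nn.Linear(3, 1)                     # stand-in for the shared per-point network
x = torch.randn(4, 5, 3)                    # (batch=4, height=5, features=3)
batch, height = x.shape[0], x.shape[1]
flat = torch.reshape(x, (batch * height, x.shape[2]))  # (20, 3): every point becomes a row
out = model(flat)                                       # the same weights score every point
out = torch.reshape(out, (batch, height, 1))            # back to (4, 5, 1)
print(out.shape)  # torch.Size([4, 5, 1])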
