Collaborative Learning

AIJack lets you experiment with collaborative learning schemes such as federated learning and split learning. All you have to do is add a few lines to your regular PyTorch code.

  • federated learning
clients = [TorchModule(), TorchModule()]
global_model = TorchModule()
server = FedAvgServer(clients, global_model)

for _ in range(epoch):

  for client in clients:
    ...  # run the usual PyTorch training loop for this client

  server.update()
  server.distribute()
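
To make the update and distribute steps above more concrete, here is a minimal sketch of weighted parameter averaging, the idea behind federated averaging. It is not AIJack's implementation: `fed_avg`, `Params`, and the two toy clients are hypothetical names, and plain Python lists stand in for tensors so the example runs without any dependencies.

```python
from typing import Dict, List

# Hypothetical stand-in for a model's parameters: name -> flat list of floats.
Params = Dict[str, List[float]]


def fed_avg(client_params: List[Params], weights: List[float]) -> Params:
    """Return the weighted average of the clients' parameters, key by key."""
    total = sum(weights)
    averaged: Params = {}
    for name in client_params[0]:
        length = len(client_params[0][name])
        averaged[name] = [
            sum(w * params[name][i] for params, w in zip(client_params, weights)) / total
            for i in range(length)
        ]
    return averaged


# Two toy clients that share the same model shape after local training.
client_a = {"weight": [1.0, 2.0], "bias": [0.0]}
client_b = {"weight": [3.0, 4.0], "bias": [2.0]}

# "Update": the server aggregates the clients' parameters (equal weights here).
global_params = fed_avg([client_a, client_b], weights=[1.0, 1.0])
print(global_params)  # {'weight': [2.0, 3.0], 'bias': [1.0]}

# "Distribute": every client starts the next round from a copy of the global parameters.
client_a = {name: list(values) for name, values in global_params.items()}
client_b = {name: list(values) for name, values in global_params.items()}
```

In practice the weights are often set to each client's number of training samples, so that clients with more data contribute more to the global model.
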
  • split learning
client_1 = SplitNNClient(first_model, user_id=0)
client_2 = SplitNNClient(second_model, user_id=1)
clients = [client_1, client_2]
splitnn = SplitNN(clients)

for _ in range(epoch):
  for x, y in dataloader:

    for opt in optimizers:
      opt.zero_grad()

    pred = splitnn(x)
    loss = criterion(pred, y)
    loss.backward()
    splitnn.backward()

    for opt in optimizers:
      opt.step()
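
The loop above hides what crosses the boundary between the two clients. As a rough illustration, the sketch below trains a two-part scalar model by hand: the first client sends only its activation forward, and the second client sends only the gradient of that activation back. This is an assumption-laden toy, not AIJack's `SplitNN`; `FirstClient`, `SecondClient`, the squared-error loss, and all numeric values are hypothetical.

```python
class FirstClient:
    """Holds the first layer (h = w * x) and the raw input."""

    def __init__(self, w: float) -> None:
        self.w = w
        self.x = 0.0

    def forward(self, x: float) -> float:
        self.x = x
        return self.w * x  # only this activation leaves the client

    def backward(self, grad_h: float, lr: float) -> None:
        # Chain rule: dL/dw = dL/dh * dh/dw = grad_h * x
        self.w -= lr * grad_h * self.x


class SecondClient:
    """Holds the second layer (pred = w * h) and the label."""

    def __init__(self, w: float) -> None:
        self.w = w
        self.h = 0.0

    def forward(self, h: float) -> float:
        self.h = h
        return self.w * h

    def backward(self, grad_pred: float, lr: float) -> float:
        grad_h = grad_pred * self.w  # gradient sent back to the first client
        self.w -= lr * grad_pred * self.h
        return grad_h


first, second = FirstClient(w=0.5), SecondClient(w=1.0)
x, y, lr = 1.0, 2.0, 0.1
losses = []

for _ in range(10):
    h = first.forward(x)        # client 1 -> client 2: activation
    pred = second.forward(h)
    losses.append((pred - y) ** 2)
    grad_pred = 2.0 * (pred - y)
    grad_h = second.backward(grad_pred, lr)  # client 2 -> client 1: gradient
    first.backward(grad_h, lr)

print(losses[0], losses[-1])
```

Neither client ever sees the other's weights, and the first client never sees the label; only `h` and `grad_h` are exchanged.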