
Add CNP #30

Open
wants to merge 3 commits into main
Conversation

AnirudhM1
Copy link
Contributor

@AnirudhM1 AnirudhM1 commented Aug 4, 2022

closes #25

Example:

```python
import jax
import jax.numpy as jnp
import optax
from flax import linen
from flax.training.train_state import TrainState

from cnp import CNP, Aggregator

# `ways` (number of classes), the data loader and `sample_batch` are assumed
# to be defined elsewhere in the training script.

class Encoder(linen.Module):
    @linen.compact
    def __call__(self, x, y, train=True):
        # x : (n_classes * shots, 20)
        # y : (n_classes * shots, 1)

        y = y.squeeze() # (n_classes * shots, )
        x = linen.Dense(32)(x)
        x = linen.relu(x)
        x = linen.Dense(64)(x)
        x = linen.relu(x)
        r = linen.Dense(32)(x) # (n_classes * shots, 32)

        aggregates = Aggregator.classification(r, y, ways)
        return aggregates # (ways * 32, )

class Decoder(linen.Module):
    @linen.compact
    def __call__(self, r, x, train=True):
        # r : (ways * 32, )
        # x : (batch_size, 20)

        # Concatenate the aggregated context representation to every query point
        x = jax.vmap(lambda a, b: jnp.concatenate([a, b]), in_axes=(0, None))(x, r) # (batch_size, 20 + ways * 32)
        x = linen.Dense(32)(x)
        x = linen.relu(x)
        x = linen.Dense(64)(x)
        x = linen.relu(x)
        x = linen.Dense(32)(x)
        x = linen.relu(x)
        x = linen.Dense(ways)(x)
        return x

key = jax.random.PRNGKey(0)
encoder_init_key, decoder_init_key = jax.random.split(key)

encoder = Encoder()
tx = optax.adam(1e-3)
params = encoder.init(encoder_init_key, jnp.ones((1, 20)), jnp.ones((1, 1)))
encoder_state = TrainState.create(apply_fn=encoder.apply, params=params, tx=tx)

decoder = Decoder()
params = decoder.init(decoder_init_key, jnp.ones((ways * 32,)), jnp.ones((1, 20)))
decoder_state = TrainState.create(apply_fn=decoder.apply, params=params, tx=tx)

n_epochs = 100
task_count = 50
for epoch in range(n_epochs):
    # Sample and train on a batch of tasks
    tasks = sample_batch(meta_train_kloader, task_count)

    encoder_state, decoder_state, loss, metrics = CNP.meta_train_step(
        encoder_state, decoder_state, tasks, loss_fn=cross_entropy, metrics=(accuracy,)
    )

    print('Epoch %2d Loss: %2.5e Avg Acc: %2.5f' % (epoch + 1, loss, metrics[0]))
    display.clear_output(wait=True)
```
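The `cross_entropy` and `accuracy` callables passed to `CNP.meta_train_step` are not defined in the snippet. A minimal sketch of compatible implementations, assuming the decoder returns per-class logits and labels are integer class ids (the exact signature expected by `cnp.py` is an assumption here), could be:

```python
import jax.numpy as jnp
from jax.nn import log_softmax

def cross_entropy(logits, labels):
    # logits: (batch, ways); labels: (batch,) integer class ids.
    log_probs = log_softmax(logits, axis=-1)
    # Pick out the log-probability of each true class and average the negatives.
    return -jnp.take_along_axis(log_probs, labels[:, None], axis=-1).mean()

def accuracy(logits, labels):
    # Fraction of examples whose argmax prediction matches the label.
    return (jnp.argmax(logits, axis=-1) == labels).mean()
```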

cc: @veds12

@AnirudhM1
Copy link
Contributor Author

AnirudhM1 commented Aug 4, 2022

A standard CNP model contains 3 parts:

  • Encoder: rᵢ = h(xᵢ, yᵢ ; θ)
  • Aggregator: r = r₁ ⊕ r₂ ⊕ … ⊕ rₙ
  • Decoder: logits = g(x, r ; φ)

For now I have considered the aggregator to be part of the encoder, because the aggregator can differ substantially between tasks. That said, I could include some common aggregators (regression, classification, etc.) and add a parameter for selecting one of them. If we go that route, the exact function definition has to be discussed.
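For illustration, a classification aggregator along these lines could mean-pool the per-point representations class-wise and flatten the result. This is a sketch of one possible choice (the function name is hypothetical), not the actual `Aggregator.classification` implementation:

```python
import jax.numpy as jnp

def classification_aggregator(r, y, ways):
    # r: (n_points, dim) per-point representations
    # y: (n_points,) integer class labels in [0, ways)
    # Scatter-add representations into per-class buckets, then divide by counts.
    sums = jnp.zeros((ways, r.shape[-1])).at[y].add(r)
    counts = jnp.zeros((ways, 1)).at[y].add(1.0)
    means = sums / jnp.maximum(counts, 1.0)
    # Flatten to a single (ways * dim,) context vector.
    return means.reshape(-1)
```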

Any thoughts on this? @veds12

@AnirudhM1 AnirudhM1 closed this Aug 4, 2022
@AnirudhM1 AnirudhM1 reopened this Aug 4, 2022
@codecov-commenter
Copy link

codecov-commenter commented Aug 4, 2022

Codecov Report

Base: 0.00% // Head: 0.00% // No change to project coverage 👍

Coverage data is based on head (f8ca69e) compared to base (d136a70).
Patch coverage: 0.00% of modified lines in pull request are covered.

Additional details and impacted files

```
@@          Coverage Diff           @@
##            main     #30    +/-   ##
======================================
  Coverage   0.00%   0.00%
======================================
  Files          5      11     +6
  Lines         95     295   +200
======================================
- Misses        95     295   +200
```

Impacted Files             Coverage Δ
jeta/cnp.py                0.00% <0.00%> (ø)
jeta/models.py             0.00% <0.00%> (ø)
jeta/maml.py               0.00% <0.00%> (ø)
jeta/data/base_dataset.py  0.00% <0.00%> (ø)
jeta/base.py               0.00% <0.00%> (ø)
jeta/anil.py               0.00% <0.00%> (ø)
jeta/data/task_dataset.py  0.00% <0.00%> (ø)

☔ View full report at Codecov.

@veds12
Copy link
Member

veds12 commented Aug 8, 2022

@AnirudhM1 a few comments:

  • I think at this stage it would be good to have a separate module for different NNs (MLPs, LSTMs). So for encoder and decoder you can just call this module
  • We can have a separate module for different types of aggregators, since these will be reused in other algorithms as well
  • Please also add proper docstrings to the code

@veds12
Copy link
Member

veds12 commented Aug 21, 2022

@AnirudhM1 any updates on this?

@AnirudhM1
Copy link
Contributor Author

@veds12 Sorry for the delay

We can have a separate module for different types of aggregators, since these will be reused in other algorithms as well

I have added a utility class with functions for some common aggregation algorithms. These functions can be called from inside the Encoder class.

Please also add proper docstrings to the code

I have added docstrings to the main code (cnp.py). The above code in the description of this PR is just an example of how to use the interface.

I think at this stage it would be good to have a separate module for different NNs (MLPs, LSTMs). So for encoder and decoder you can just call this module

Can you please elaborate a bit further on this?
Currently, this model works with an arbitrary architecture for the encoder and decoder.

@AnirudhM1 AnirudhM1 marked this pull request as ready for review August 23, 2022 15:16
@veds12
Copy link
Member

veds12 commented Aug 25, 2022

Can you please elaborate a bit further on this?
Currently, this model works with an arbitrary architecture for encoder and decoder

I mean having a separate file from which different architectures (MLP, RNNs, CNNs) can be loaded.

Successfully merging this pull request may close these issues.

Add Conditional Neural Processes (CNP)
3 participants