gencast_mini_demo.ipynb on AMD CPU #113

Open
dkokron opened this issue Dec 21, 2024 · 38 comments

Comments

@dkokron

dkokron commented Dec 21, 2024

I'm attempting to run the gencast_mini_demo.ipynb case on my home workstation without a GPU. The notebook recognizes that I don't have the correct software to run on the installed GPU and falls back to CPU (which is what I want to happen).

Output from cell 22.
WARNING:2024-12-21 14:22:21,184:jax._src.xla_bridge:969: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

I've attached the stack trace I get from cell 23 (Autoregressive rollout (loop in python)).
gencast.failure.txt

Is this expected? Does GenCast require a GPU or TPU to work?

@andrewlkd
Collaborator

Hey,

This looks like a splash attention related error. Splash attention is only supported on TPU.

You can try following the GPU instructions to change the attention mechanism; I believe this should work fine on CPU.

Note that without knowing the memory specifications of your device, I can't guarantee it won't run out of memory. We've also never run GenCast on CPU so cannot make any guarantees around its correctness.

Hope that helps!

Andrew

@dkokron
Author

dkokron commented Dec 21, 2024

I will try your suggestion and report back here.

@dkokron
Author

dkokron commented Dec 22, 2024

I followed the suggestion in the "Running Inference on GPU" section of cloud_vm_setup.md

task_config = ckpt.task_config
sampler_config = ckpt.sampler_config
noise_config = ckpt.noise_config
noise_encoder_config = ckpt.noise_encoder_config
denoiser_architecture_config = ckpt.denoiser_architecture_config
denoiser_architecture_config.sparse_transformer_config.attention_type = "triblockdiag_mha"
denoiser_architecture_config.sparse_transformer_config.mask_type = "full"

The job (4 time steps and 8 members) ran for about 2h:30m using 17GB of system RAM with an averaged CPU load of ~30 (I have 48 cores). Unfortunately, the results are all NaN.

GenCast/graphcast/GenCast/lib/python3.12/site-packages/numpy/lib/_nanfunctions_impl.py:1409: RuntimeWarning: All-NaN slice encountered
return _nanquantile_unchecked(

@andrewlkd
Collaborator

I can't say I've seen this warning before. Could you confirm if the entire forecast was NaN? Note that we expect NaNs in the sea surface temperature variable so I wonder if this is what you might be encountering.
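A small helper that reports the fraction of NaN values in each variable makes this quick to check. This is a minimal sketch. `nan_fraction_per_variable` is a hypothetical helper, not part of graphcast, and it assumes each forecast variable can be converted to a NumPy array (for an xarray Dataset, pass `predictions.data_vars`). A value of 1.0 means the variable is entirely NaN. sea_surface_temperature is expected to show some NaNs.

```python
import math
from typing import Mapping

import numpy as np


def nan_fraction_per_variable(variables: Mapping[str, object]) -> dict:
    """Fraction of NaN entries in each variable (hypothetical helper).

    `variables` is any mapping of name -> array-like, e.g. an xarray
    Dataset's `.data_vars`.
    """
    fractions = {}
    for name, values in variables.items():
        arr = np.asarray(values, dtype=np.float64)
        # An empty array has no meaningful fraction; report NaN for it.
        fractions[name] = float(np.isnan(arr).mean()) if arr.size else math.nan
    return fractions


# Toy data (made up): one fully-NaN field and one partially-NaN field.
example = {
    "2m_temperature": np.full((2, 3), np.nan),
    "sea_surface_temperature": np.array([[1.0, np.nan], [2.0, 3.0]]),
}
fractions = nan_fraction_per_variable(example)
all_nan = [name for name, frac in fractions.items() if frac == 1.0]
print(all_nan)  # ['2m_temperature']
```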

@dkokron
Author

dkokron commented Dec 24, 2024

I was plotting 2m_temp for all 8 ensemble members. All members had this same warning. I'll need to run it again to view other variables.

@dkokron
Author

dkokron commented Dec 24, 2024

specific humidity at 850 and 100, vertical speed at 850, geopotential at 500 and u and v components of wind at 925 are also NaN. I did not look at the rest.

@dkokron
Author

dkokron commented Dec 29, 2024

Any more ideas on how to investigate this issue?

@andrewlkd
Collaborator

Unfortunately, we've never attempted to run the model on a CPU, as this is too slow for practical use. In principle there should be no reason for the results to differ, but unexpected device-specific compilation issues may be manifesting here. In the meantime, hopefully the instructions on how to use free cloud compute are useful.

Do let us know if you gain any insights on why this is happening.

@dkokron
Author

dkokron commented Jan 7, 2025

If you've never attempted to run it on a decent CPU, then how do you know it won't be practical?
I'll see if I can figure out what is going wrong and report back here.

@guidov

guidov commented Jan 7, 2025

I also think it would be nice to be able to set up the model config and run it for one timestep on our own CPU systems, and then move it to a cloud GPU or TPU. CPU systems have very large amounts of RAM nowadays.
I tried playing with this and also got NaNs.

I set the following in the notebook, but if the CPU device count is greater than 1, I get an AssertionError.

import os

import jax
from jax import config

config.update("jax_platform_name", "cpu")

# Set the environment variable (it must be set before JAX initialises its backends)
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=24'
# Verify it's set
print(f"XLA_FLAGS: {os.getenv('XLA_FLAGS')}")

print(jax.devices())
jax.local_device_count(backend='cpu')

In the 'build jitted' section:

loss_fn_jitted = jax.jit(
    lambda rng, i, t, f: loss_fn.apply(params, state, rng, i, t, f)[0],
    backend='cpu')
grads_fn_jitted = jax.jit(grads_fn, backend='cpu')
run_forward_jitted = jax.jit(
    lambda rng, i, t, f: run_forward.apply(params, state, rng, i, t, f)[0],
    backend='cpu')
# We also produce a pmapped version for running in parallel.
run_forward_pmap = xarray_jax.pmap(run_forward_jitted, dim="sample", backend='cpu')

@andrewlkd Maybe #108 can be of some use; however, I don't understand how JAX is working here with the CPUs.

When the CPU device count is set to 1, it uses all the CPUs anyway.
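For reference, here is a minimal sketch of the usual way to get several fake CPU devices. Both environment variables are set before `jax` is imported, so they are already in place when XLA initialises. The device count of 4 is arbitrary. This does not explain the AssertionError above. It only shows the setup that `pmap` over CPU devices expects.

```python
import os

# Both variables must be in place before JAX initialises its backends;
# setting them before the first `import jax` is the simplest way to ensure it.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp

print(jax.devices())
n_devices = jax.local_device_count()

# A trivial pmap over the fake CPU devices: each device sums one row.
x = jnp.arange(n_devices * 3, dtype=jnp.float32).reshape(n_devices, 3)
row_sums = jax.pmap(jnp.sum)(x)
print(row_sums)
```

With a single device, XLA still uses multiple host threads inside each operation, which is consistent with all cores being busy even at a device count of 1.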

@dkokron
Author

dkokron commented Jan 12, 2025

Results from debugging so far are attached. I put a breakpoint in chunked_prediction_generator() in rollout.py, just before the call to predictor_fn(), printed some variables to look for NaNs, and then hit continue. The stack trace is in the attached text file. Please review it and let me know if this helps shed any light on how the NaNs are being generated.

Debugging.txt

@andrewlkd
Copy link
Collaborator

Hm, I'm not sure this sheds light. It just suggests that something in the predictor function itself (i.e. the forward pass of GenCast) is causing NaNs when running on CPU.

In case it had something to do with the pmapping, I just tried running the non-pmapped case on my end, and it still produces NaNs.

Let me know if you get any more data points from debugging.

@dkokron
Author

dkokron commented Jan 20, 2025

@andrewlkd PMAP is interfering with my debugging efforts. I'm running into the limitations described at
https://jax.readthedocs.io/en/latest/debugging/flags.html

Would you share your code changes to run the gencast_mini_demo.ipynb demo non-pmapped?

@andrewlkd
Collaborator

Sure!

In the demo notebook, you'll want to:

  1. Redefine
run_forward_jitted = jax.jit(
    lambda rng, inputs, targets_template, forcings: run_forward.apply(params, state, rng, inputs, targets_template, forcings)[0]
)
  2. Pass in the non-pmapped predictor_fn and remove the pmap_devices argument to rollout.chunked_prediction_generator_multiple_runs:
for chunk in rollout.chunked_prediction_generator_multiple_runs(
    # Use pmapped version to parallelise across devices.
    # predictor_fn=run_forward_pmap,
    predictor_fn=run_forward_jitted,
    rngs=rngs,
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings,
    num_steps_per_chunk = 1,
    num_samples = num_ensemble_members,
    # pmap_devices=jax.local_devices()
    ):
    chunks.append(chunk)
predictions = xarray.combine_by_coords(chunks)

Hope this helps,

Andrew

@alvarosg
Collaborator

chex.fake_pmap may also come in handy.

Ultimately though, what I find the most useful for debugging this kind of thing is to keep the pmap on, and then use jax.debug.print or [callbacks](https://jax.readthedocs.io/en/latest/external-callbacks.html) to run arbitrary python code on intermediate activations.
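Here is a minimal sketch of that approach, using a toy function in place of the GenCast forward pass. `jax.debug.print` and `jax.debug.callback` both run on the host with concrete values, even inside `jax.jit`. That means NaN checks can be attached to intermediate activations without disabling compilation. The function and callback names are illustrative only.

```python
import functools

import jax
import jax.numpy as jnp

captured = []


def record_nan_count(name, count):
    # Runs on the host with a concrete value, even though `layer` is jitted.
    captured.append((name, int(count)))


@jax.jit
def layer(x):
    h = jnp.log(x)  # toy stand-in for an intermediate activation
    nan_count = jnp.isnan(h).sum()
    jax.debug.print("log output NaN count: {}", nan_count)
    jax.debug.callback(functools.partial(record_nan_count, "log_output"), nan_count)
    return jnp.nan_to_num(h)


out = layer(jnp.array([1.0, -1.0, 2.0]))
jax.effects_barrier()  # make sure pending callbacks have run
print(captured)  # [('log_output', 1)]
```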

@dkokron
Author

dkokron commented Jan 23, 2025

Eliminating the pmap as Andrew suggested, and adding these debug lines to the top of the notebook after the imports, results in the following trace. Setting these to False lets the code run as before and generate NaNs.

jax.config.update("jax_debug_nans", True)
jax.config.update("jax_disable_jit", True)

.
.
.
.

File /scratch/dkokron/Projects/GenCast/graphcast/graphcast/gencast.py:285, in GenCast.__call__(self, inputs, targets_template, forcings, **kwargs)
    281 if self._sampler is None:
    282   self._sampler = dpm_solver_plus_plus_2s.Sampler(
    283       self._preconditioned_denoiser, **self._sampler_config
    284   )
--> 285 return self._sampler(inputs, targets_template, forcings, **kwargs)

File /scratch/dkokron/Projects/GenCast/graphcast/graphcast/dpm_solver_plus_plus_2s.py:186, in Sampler.__call__(self, inputs, targets_template, forcings, **kwargs)
    183 # Init with zeros but apply additional noise at step 0 to initialise the
    184 # state.
    185 noise_init = xarray.zeros_like(targets_template)
--> 186 return hk.fori_loop(
    187     0, len(noise_levels) - 1, body_fun=body_fn, init_val=noise_init)

File /scratch/dkokron/Projects/GenCast/graphcast/GenCast/lib/python3.12/site-packages/haiku/_src/stateful.py:697, in fori_loop(lower, upper, body_fun, init_val)
    695 state = internal_state()
    696 init_val = state, init_val
--> 697 state, val = jax.lax.fori_loop(lower, upper, pure_body_fun, init_val)
    698 update_internal_state(state)
    699 return val

File /scratch/dkokron/Projects/GenCast/graphcast/GenCast/lib/python3.12/site-packages/haiku/_src/stateful.py:677, in fori_loop.<locals>.pure_body_fun(i, val)
    674 state, val = val
    675 with temporary_internal_state(state), \
    676      base.push_jax_trace_level():
--> 677   val = body_fun(i, val)
    678   reserve_up_to_full_rng_block()
    679   state = internal_state()

File /scratch/dkokron/Projects/GenCast/graphcast/graphcast/dpm_solver_plus_plus_2s.py:134, in Sampler.__call__.<locals>.body_fn(i, x)
    126   return noise_levels[0] * utils.spherical_white_noise_like(template)
    128 # Initialise the inputs if i == 0.
    129 # This is done here to ensure both noise sampler calls can use the same
    130 # spherical harmonic basis functions. While there may be a small compute
    131 # cost the memory savings can be significant.
    132 # TODO(dominicmasters): Figure out if we can merge the two noise sampler
    133 # calls into one to avoid this hack.
--> 134 maybe_init_noise = (i == 0).astype(noise_levels[0].dtype)
    135 x = x + init_noise(x) * maybe_init_noise
    137 noise_level = noise_levels[i]

AttributeError: 'bool' object has no attribute 'astype'
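The `jax_debug_nans` flag enabled in this comment can also be checked on its own, separately from the AttributeError. This is a minimal sketch using a toy computation. With the flag on, JAX raises `FloatingPointError` when an operation outputs a NaN. That is how the traceback ends up pointing at the offending line.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)
try:
    try:
        jnp.log(jnp.array(-1.0))  # log of a negative number is NaN
        raised = False
    except FloatingPointError as err:
        raised = True
        print("caught:", err)
finally:
    # Restore the default so later code is not slowed down or interrupted.
    jax.config.update("jax_debug_nans", False)
```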

@dkokron
Author

dkokron commented Jan 23, 2025

Forgot to mention that I also set the following in my test:
num_ensemble_members = 1

@dkokron
Author

dkokron commented Jan 24, 2025

Workaround for "AttributeError: 'bool' object has no attribute 'astype'":

maybe_init_noise = jnp.array([int(i == 0)]).astype(noise_levels[0].dtype)
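An alternative sketch, not from the thread: `jnp.asarray(i == 0, dtype=...)` accepts both a Python bool (eager mode or `jax_disable_jit`) and a traced boolean (inside `jit`/`fori_loop`). It returns a scalar rather than the shape `(1,)` array produced by the workaround above. `init_noise_flag` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


def init_noise_flag(i, dtype=jnp.float32):
    # `i == 0` is a Python bool when `i` is a Python int (eager mode or
    # jax_disable_jit) and a boolean tracer inside jit / lax.fori_loop.
    # jnp.asarray turns either into a scalar (shape ()) of the given dtype.
    return jnp.asarray(i == 0, dtype=dtype)


# Eager use with Python ints.
eager_flags = [float(init_noise_flag(i)) for i in range(3)]


# Traced use inside a jitted fori_loop: only iteration 0 contributes.
@jax.jit
def total_flags(n):
    return jax.lax.fori_loop(
        0, n, lambda i, acc: acc + init_noise_flag(i), jnp.float32(0.0))


print(eager_flags, total_flags(5))
```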

@dkokron
Author

dkokron commented Jan 24, 2025

Found it!
I added some debugging to dpm_solver_plus_plus_2s.py and denoiser.py (both attached).

The issue is that next_noise_level goes to zero at the 20th iteration (i=19) of body_fn(). That makes mid_noise_level zero as well. That zero is then passed to FourierFeaturesMLP(), which takes the natural logarithm of that value. The full traceback is attached as well.

traceback.txt
dpm_solver_plus_plus_2s.py.txt
denoiser.py.txt

@andrewlkd
Collaborator

Nice progress! Unfortunately, I'm not sure this is related to the NaN outputs you're seeing.

Note that when next_noise_level is 0, the sampler returns x_denoised, not x_next:

return utils.tree_where(next_noise_level == 0, x_denoised, x_next)

x_denoised is independent of mid_noise_level, i.e. the presence of these NaNs is expected. You'll want to skip over these NaNs in some fashion (I'd probably hack the code so they aren't produced in the last iteration) so that you can continue debugging and see where the NaNs that affect the output are produced.
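Here is a minimal sketch of that kind of hack, assuming the problem is a `log` applied to a value that can be exactly zero. The common "double where" pattern substitutes a safe value before calling `log`. The unused branch then never contains `-inf`, which later arithmetic can turn into NaN (for example, `jnp.sin(-jnp.inf)` is NaN). `log_or_zero` is hypothetical and is not the FourierFeaturesMLP code.

```python
import jax
import jax.numpy as jnp


def log_or_zero(v):
    # Hypothetical helper: log(v) where v > 0, and 0.0 where v == 0.
    # The inner where swaps zeros for 1.0 *before* calling log, so the
    # discarded branch never contains -inf (or its NaN gradient).
    is_zero = v == 0
    safe_v = jnp.where(is_zero, 1.0, v)
    return jnp.where(is_zero, 0.0, jnp.log(safe_v))


def naive_log_or_zero(v):
    # Selects the same values, but log(0) = -inf is still computed.
    return jnp.where(v == 0, 0.0, jnp.log(v))


v = jnp.array([0.0, 1.0, 4.0])
values = log_or_zero(v)
safe_grad = jax.grad(lambda x: log_or_zero(x).sum())(v)
naive_grad = jax.grad(lambda x: naive_log_or_zero(x).sum())(v)
print(values, safe_grad, naive_grad)  # naive_grad[0] is nan
```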

-- Andrew

@dkokron
Author

dkokron commented Jan 24, 2025

Yeah, you are right. Coding around that ln(0.) still fails with NaNs downstream. Still looking.

@dkokron
Author

dkokron commented Jan 26, 2025

@andrewlkd please try running your CPU test with the following setting just after cell 3. I was able to get a reasonable prediction with this setting.

jax.config.update("jax_disable_jit", True)

@dkokron
Author

dkokron commented Jan 26, 2025

My system is running
jax: 0.4.38
jaxlib: 0.4.38
numpy: 2.2.0
python: 3.12.3 (main, Jan 17 2025, 18:03:48) [GCC 13.3.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='SerenityTwo', release='6.8.0-51-generic', version='#52-Ubuntu SMP PREEMPT_DYNAMIC Thu Dec 5 13:09:44 UTC 2024', machine='x86_64')

@andrewlkd
Collaborator

Interesting. This suggests it's indeed an XLA-compilation related issue.

You may wish to compare the outputs of this with the outputs generated when running on the free Colab TPU to ensure that the forecasts being generated are indeed sensible.

@dkokron
Author

dkokron commented Jan 27, 2025

I traced the issue to apply_stochastic_churn(). During the 16th iteration of body_fn(), "x" is fine after adding init_noise(x), but it goes bad after apply_stochastic_churn().

The relevant code is

  new_noise_level = noise_level * (1.0 + stochastic_churn_rate)
  extra_noise_stddev = (jnp.sqrt(new_noise_level**2 - noise_level**2)
                        * noise_level_inflation_factor)
  updated_x = x + spherical_white_noise_like(x) * extra_noise_stddev
  # Check 2m_temperature for NaNs
  L=(updated_x["2m_temperature"].isnull()).any()
  jax.debug.print("apply_stochastic_churn: x after applying stochastic_churn: 2m_temperature: {} {} {}", L, new_noise_level, extra_noise_stddev, ordered=True)

Note that per_step_churn_rates is: [0.125 0.125 0.125 0.125 0.125 0.125 0.125 0.125 0.125 0.125 0.125 0.125 0.125 0.125 0. 0. 0. 0. 0. 0. ]

When stochastic_churn_rate=0.0, new_noise_level equals noise_level, so new_noise_level**2 - noise_level**2 should be zero. Taking the sqrt() of zero is fine, but I've read that it causes problems when computing the gradient. At any rate, extra_noise_stddev becomes NaN, as shown by the print I put at the bottom of apply_stochastic_churn():

apply_stochastic_churn: x after applying stochastic_churn: 2m_temperature: <xarray.DataArray '2m_temperature' ()> Size: 1B
xarray_jax.JaxArrayWrapper(Array(True, dtype=bool)) 0.22014613449573517 nan

@andrewlkd
Collaborator

I've read that this causes problems when computing the gradient

Note that the sampler isn't being backpropagated through (it is only used at inference time), so I'm not sure gradients of this function are relevant here.

What about checking the value of new_noise_level**2 - noise_level**2? It's possible this is resolving to a value < 0 due to floating point numerics.

@dkokron
Author

dkokron commented Jan 27, 2025

I'll look into that calculation.

@dkokron
Author

dkokron commented Jan 27, 2025

That's it! The difference goes negative. Here is the code.

new_noise_level = noise_level * (1.0 + stochastic_churn_rate)
diff = new_noise_level**2 - noise_level**2
extra_noise_stddev = (jnp.sqrt(new_noise_level**2 - noise_level**2)
                      * noise_level_inflation_factor)
updated_x = x + spherical_white_noise_like(x) * extra_noise_stddev
# Check 2m_temperature for NaNs
L = (updated_x["2m_temperature"].isnull()).any()
jax.debug.print("apply_stochastic_churn: x after applying stochastic_churn: 2m_temperature: {} {} {} {} {}", L, new_noise_level, extra_noise_stddev, diff, noise_level_inflation_factor, ordered=True)

And output from the run

body_fn: x after init_noise: 2m_temperature: 16 <xarray.DataArray '2m_temperature' ()> Size: 1B
xarray_jax.JaxArrayWrapper(Array(False, dtype=bool))
body_fn: noise_level: 0.22014613449573517
apply_stochastic_churn: x after applying stochastic_churn: 2m_temperature: <xarray.DataArray '2m_temperature' ()> Size: 1B
xarray_jax.JaxArrayWrapper(Array(True, dtype=bool)) 0.22014613449573517 nan -6.661848850342267e-11 1.0499999523162842

@andrewlkd
Collaborator

Nice!

Does adding something like

extra_noise = jnp.maximum(0, new_noise_level**2 - noise_level**2)
extra_noise_stddev = jnp.sqrt(extra_noise) * noise_level_inflation_factor

generate non-NaN forecasts for you now?

@dkokron
Author

dkokron commented Jan 27, 2025

The following modification resolves this issue.

diff = new_noise_level**2 - noise_level**2
# extra_noise_stddev = (jnp.sqrt(new_noise_level**2 - noise_level**2)
# Clamp tiny negative round-off to zero before taking the square root.
extra_noise_stddev = (jnp.sqrt(jnp.where(diff < 0., 0., diff))
                      * noise_level_inflation_factor)
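Another option, not proposed in the thread, avoids the cancellation instead of clamping it. Since new_noise_level = noise_level * (1 + c), the difference of squares factors as noise_level**2 * c * (2 + c). For non-negative noise_level and c, the square root is therefore noise_level * sqrt(c * (2 + c)). The sketch below works on plain JAX scalars; the real sampler code operates on xarray-wrapped values, and `extra_noise_stddev` here is a standalone illustration.

```python
import jax
import jax.numpy as jnp


def extra_noise_stddev(noise_level, churn_rate, inflation_factor):
    # For noise_level >= 0 and churn_rate c >= 0 this equals
    #   sqrt((noise_level * (1 + c))**2 - noise_level**2) * inflation_factor
    # but never subtracts two nearly equal squares: c * (2 + c) is exactly
    # 0.0 when c == 0 and cannot be negative for c >= 0.
    c = jnp.asarray(churn_rate, dtype=jnp.float32)
    return noise_level * jnp.sqrt(c * (2.0 + c)) * inflation_factor


f = jax.jit(extra_noise_stddev)
noise_level = jnp.float32(0.22014613449573517)
zero_churn = f(noise_level, 0.0, 1.05)
some_churn = f(noise_level, 0.125, 1.05)
print(zero_churn, some_churn)
```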

@dkokron
Author

dkokron commented Jan 27, 2025

Would the "where" solution or the "maximum" solution be preferable?

@dkokron
Author

dkokron commented Jan 27, 2025

jnp.maximum() is also a solution here.

extra_noise_stddev = (jnp.sqrt(jnp.maximum(0., diff))
                      * noise_level_inflation_factor)
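Here is a quick sketch comparing the two clamps on a few float32 inputs, including the -6.66e-11 difference observed above. For finite inputs both return the same values. In the traced program, the `maximum` form is a single op, whereas the `where` form is a compare plus a select. Neither has been benchmarked here, and either is likely negligible next to a denoiser forward pass.

```python
import jax
import jax.numpy as jnp


@jax.jit
def clamp_where(diff):
    return jnp.where(diff < 0.0, 0.0, diff)


@jax.jit
def clamp_maximum(diff):
    return jnp.maximum(0.0, diff)


# Includes the -6.66e-11 difference observed in the thread.
diffs = jnp.array([-6.661849e-11, -1.0, 0.0, 1e-12, 0.25], dtype=jnp.float32)
a = clamp_where(diffs)
b = clamp_maximum(diffs)
print(a, b, jnp.sqrt(b))
```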

@andrewlkd
Copy link
Collaborator

What does running

x = 0.22014613449573517 ** 2
x ** 2 - x ** 2

return on your end? It's a bit bizarre that this is happening...

@dkokron
Author

dkokron commented Jan 27, 2025

I get 0.0 if I run that in a cell of my notebook.

@dkokron
Author

dkokron commented Jan 27, 2025

I also get zero with

import jax

@jax.jit
def f(x1):
    x2 = x1 * (1.0 + 0.0)
    d = x2**2 - x1**2
    jax.debug.print("{}", d)
    return d

f(0.22014613449573517)

@dkokron
Author

dkokron commented Jan 27, 2025

and with

import jax

def f(x1):
    x2 = x1 * (1.0 + 0.0)
    d = x2**2 - x1**2
    jax.debug.print("{}", d)
    return d

f_jit = jax.jit(f)
f(0.22014613449573517)
f_jit(0.22014613449573517)

@dkokron
Author

dkokron commented Jan 27, 2025

import jax
import jax.typing  # for the ArrayLike annotations

def f(
  x1: jax.typing.ArrayLike,
  s: jax.typing.ArrayLike):
  x2 = x1 * (1.0 + s)
  d = x2**2 - x1**2
  jax.debug.print("{}", d)
  return d

f_jit = jax.jit(f)
f(0.22014613449573517, 0.0)
f_jit(0.22014613449573517, 0.0)

results in:
0.0
-6.661848850342267e-11

Array(-6.661849e-11, dtype=float32, weak_type=True)
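Note that the un-jitted call receives Python floats and so computes in float64, while the jitted call receives float32 arrays. One possible explanation for the non-zero jitted result is that the compiler contracts x2**2 - x1**2 into a fused multiply-add; this thread does not confirm that. An FMA keeps one product unrounded and so returns the rounding error of the other square instead of zero. The NumPy sketch below only simulates that arithmetic and does not inspect what XLA actually generates. The observed magnitude, 6.66e-11, is well below the float32 spacing near 0.0485 (about 3.7e-9). That is consistent with this explanation but does not prove it.

```python
import numpy as np

x = np.float32(0.22014613449573517)

# Ordinary float32 arithmetic: both squares are rounded identically,
# so their difference is exactly zero.
sq = x * x  # float32 result, rounded
plain = sq - x * x

# Simulated fused multiply-add fma(x, x, -sq): the product x*x is kept exact
# (a float64 holds the 48-bit product of two float32 significands exactly)
# and only the final result is rounded, leaving the rounding error of sq.
fused = np.float32(np.float64(x) * np.float64(x) - np.float64(sq))

print(plain, fused, np.spacing(sq))
```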

@dkokron
Author

dkokron commented Jan 27, 2025

I see the same behavior under

jax: 0.5.0
jaxlib: 0.5.0
numpy: 2.2.2
python: 3.12.3

Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants