Skip to content

Commit

Permalink
- set tensorflow<=2.15 compatibility, higher versions do not work pro…
Browse files Browse the repository at this point in the history
…perly, do not train

- fixed bug with deleted samplers in utils
  • Loading branch information
sgrubas committed Aug 4, 2024
1 parent fac28e3 commit fac0393
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 5 deletions.
4 changes: 2 additions & 2 deletions NES/NeuralEikonalSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(self, xs, velocity, eikonal=None, name=None):
else:
assert isinstance(eikonal, L.Layer), "Eikonal should be an instance of keras Layer"
self.equation = eikonal
self.equation.name = f'{self.name}_{self.equation.name}'
self.equation._name = f'{self.name}_{self.equation.name}'

self.sing_eps = 1e-5 # tolerance for source singularity (to remove from training)
self.x_train = None # input training data
Expand Down Expand Up @@ -540,7 +540,7 @@ def __init__(self, velocity, eikonal=None, name=None):
else:
assert isinstance(eikonal, L.Layer), "Eikonal should be an instance of keras Layer"
self.equation = eikonal
self.equation.name = f'{self.name}_{self.equation.name}'
self.equation._name = f'{self.name}_{self.equation.name}'

self.sing_eps = 1e-5 # tolerance for source singularity (to remove from training)
self.x_train = None # input training data
Expand Down
3 changes: 1 addition & 2 deletions NES/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,7 @@ def data_handler(x, y, **kwargs):
callbacks = kwargs.get('callbacks', [])
generator_required = False
for c in callbacks:
if isinstance(c, (experimental.ImportanceSampling, experimental.ImportanceWeighting,
experimental.RARsampling, experimental.FromCoarseToFineResampling)):
if isinstance(c, (experimental.RARsampling, experimental.FromCoarseToFineResampling)):
generator_required = True
break

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
numpy<2.0.0
scipy>=1.13.0
tensorflow<=2.17.0
tensorflow<=2.15.0
setuptools>=70.0.0

0 comments on commit fac0393

Please sign in to comment.