diff --git a/_modules/hippynn/custom_kernels/env_triton.html b/_modules/hippynn/custom_kernels/env_triton.html
new file mode 100644
index 00000000..8338d225
--- /dev/null
+++ b/_modules/hippynn/custom_kernels/env_triton.html
@@ -0,0 +1,457 @@
Source code for hippynn.custom_kernels.env_triton

+import torch
+import triton
+import triton.language as tl
+from .utils import resort_pairs_cached
+
+# If numba is available, CPU tensors fall back to the numba implementation; otherwise they fall back to vanilla pytorch.
+try:
+    from .env_numba import new_envsum as envsum_alternative, new_sensesum as sensesum_alternative, new_featsum as featsum_alternative
+except ImportError:
+    # Load backup implementation for CPU tensors.
+    from .env_pytorch import envsum as envsum_alternative, sensesum as sensesum_alternative, featsum as featsum_alternative
+
+
+[docs] +def config_pruner(configs, nargs, **kwargs): + """ + Trims the unnecessary config options based on the sens. and feat. sizes + """ + #print("For some reason the config pruner also gets arguments:",kwargs) + p2_sens_size = triton.next_power_of_2(nargs["sens_size"]) + p2_feat_size = triton.next_power_of_2(nargs["feat_size"]) + + used = set() + for config in configs: + + # Don't use block sizes bigger than p2_sens_size or p2_feat_size; they will give the same result + # because there will only be one block. + sense_block_size = min(p2_sens_size, config.kwargs["SENS_BLOCK_SIZE"]) + feat_block_size = min(p2_feat_size, config.kwargs["FEAT_BLOCK_SIZE"]) + + if (sense_block_size, feat_block_size, config.num_stages, config.num_warps) in used: + continue + + used.add((sense_block_size, feat_block_size, config.num_stages, config.num_warps)) + + yield triton.Config( + { + "SENS_BLOCK_SIZE": sense_block_size, + "FEAT_BLOCK_SIZE": feat_block_size, + }, + num_stages=config.num_stages, + num_warps=config.num_warps, + )
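For intuition, here is a minimal sketch (not part of the diff above) of how the power-of-two cap collapses redundant configurations; triton.next_power_of_2 is the standard Triton utility, and the sizes are made up:

import triton

# Hypothetical problem: 20 sensitivities, 80 features.
p2_sens = triton.next_power_of_2(20)   # -> 32
p2_feat = triton.next_power_of_2(80)   # -> 128

# A config requesting SENS_BLOCK_SIZE=256 is capped to 32: one block already
# covers every sensitivity, so a larger block cannot change the result.
sens_block = min(p2_sens, 256)   # -> 32
feat_block = min(p2_feat, 256)   # -> 128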
+ + +
+[docs] +def get_autotune_config(): + """ + Create a list of config options for the kernels + TODO: Need to spend time actually figuring out more reasonable options + targeted for modern GPUs + """ + return [ + triton.Config({"SENS_BLOCK_SIZE": 16, "FEAT_BLOCK_SIZE": 16}), + triton.Config({"SENS_BLOCK_SIZE": 16, "FEAT_BLOCK_SIZE": 32}), + triton.Config({"SENS_BLOCK_SIZE": 16, "FEAT_BLOCK_SIZE": 64}), + triton.Config({"SENS_BLOCK_SIZE": 16, "FEAT_BLOCK_SIZE": 128}), + triton.Config({"SENS_BLOCK_SIZE": 16, "FEAT_BLOCK_SIZE": 256}), + triton.Config({"SENS_BLOCK_SIZE": 32, "FEAT_BLOCK_SIZE": 32}), + triton.Config({"SENS_BLOCK_SIZE": 32, "FEAT_BLOCK_SIZE": 64}), + triton.Config({"SENS_BLOCK_SIZE": 32, "FEAT_BLOCK_SIZE": 128}), + triton.Config({"SENS_BLOCK_SIZE": 32, "FEAT_BLOCK_SIZE": 128}, num_warps=8), + triton.Config({"SENS_BLOCK_SIZE": 32, "FEAT_BLOCK_SIZE": 256}), + triton.Config({"SENS_BLOCK_SIZE": 32, "FEAT_BLOCK_SIZE": 256}, num_warps=8), + triton.Config({"SENS_BLOCK_SIZE": 64, "FEAT_BLOCK_SIZE": 32}), + triton.Config({"SENS_BLOCK_SIZE": 64, "FEAT_BLOCK_SIZE": 64}), + triton.Config({"SENS_BLOCK_SIZE": 64, "FEAT_BLOCK_SIZE": 128}), + triton.Config({"SENS_BLOCK_SIZE": 64, "FEAT_BLOCK_SIZE": 256}), + triton.Config({"SENS_BLOCK_SIZE": 128, "FEAT_BLOCK_SIZE": 32}), + triton.Config({"SENS_BLOCK_SIZE": 128, "FEAT_BLOCK_SIZE": 64}), + triton.Config({"SENS_BLOCK_SIZE": 128, "FEAT_BLOCK_SIZE": 64}, num_warps=8), + triton.Config({"SENS_BLOCK_SIZE": 256, "FEAT_BLOCK_SIZE": 32}), + triton.Config({"SENS_BLOCK_SIZE": 256, "FEAT_BLOCK_SIZE": 64}), + triton.Config({"SENS_BLOCK_SIZE": 256, "FEAT_BLOCK_SIZE": 64}, num_warps=8), + ]
+ + + +@triton.autotune(configs=get_autotune_config(), key=["sens_size", "feat_size"], prune_configs_by={"early_config_prune": config_pruner}) +@triton.jit +def envsum_kernel( + out_env_ptr, + sens_ptr, + feat_ptr, + psecond_ptr, + atom_ids_ptr, + atom_starts_ptr, + atom_size, + sens_size: tl.constexpr, + feat_size: tl.constexpr, + SENS_BLOCK_SIZE: tl.constexpr, + FEAT_BLOCK_SIZE: tl.constexpr, + dtype: tl.constexpr = tl.float32, +): + atom_id = tl.program_id(axis=0) + sens_id = tl.program_id(axis=1) + feat_id = tl.program_id(axis=2) + + valid_atom_id = atom_id < atom_size + + start = tl.load(atom_starts_ptr + atom_id, mask=valid_atom_id, other=0) + end = tl.load(atom_starts_ptr + atom_id + 1, mask=valid_atom_id, other=0) + target_id = tl.load(atom_ids_ptr + atom_id, mask=valid_atom_id, other=0) + + sens_block_ids = tl.arange(0, SENS_BLOCK_SIZE) + (sens_id * SENS_BLOCK_SIZE) + feat_block_ids = tl.arange(0, FEAT_BLOCK_SIZE) + (feat_id * FEAT_BLOCK_SIZE) + env_block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] + + valid_sens = sens_block_ids < sens_size + valid_feat = feat_block_ids < feat_size + valid_env = valid_sens[:, None] & valid_feat[None, :] + + tmp = tl.zeros((SENS_BLOCK_SIZE, FEAT_BLOCK_SIZE), dtype=dtype) + + for ind in range(start, end): + # [SENS_BLOCK_SIZE,], coming from the pair sensitivity + s = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, mask=valid_sens, other=0.0) + atom2_id = tl.load(psecond_ptr + ind) + # [FEAT_BLOCK_SIZE,], coming from the neighbor feature + feat = tl.load(feat_ptr + (atom2_id * feat_size) + feat_block_ids, mask=valid_feat, other=0.0) + # temp_mat and tmp is [SENS_BLOCK_SIZE, FEAT_BLOCK_SIZE] + temp_mat = s[:, None] * feat[None, :] + tmp = tmp + temp_mat + + atom_offset = target_id * sens_size * feat_size + + # TODO: use sparsity of sensitivities to reduce workload? (see numba envsum implementation) + tl.store(out_env_ptr + atom_offset + env_block_ids, tmp, mask=valid_env) + + +
+[docs] +def envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, atom_starts, out_env=None): + n_pairs, n_nu = sensitivities.shape + n_atom, n_feat = features.shape + (n_atom_with_pairs,) = atom_ids.shape + + if out_env is None: + out_env = torch.zeros((n_atom, n_nu, n_feat), dtype=features.dtype, device=features.device) + + dtype = tl.float32 + if features.dtype == torch.float64: + dtype = tl.float64 + + grid = lambda META: (n_atom_with_pairs, triton.cdiv(n_nu, META["SENS_BLOCK_SIZE"]), triton.cdiv(n_feat, META["FEAT_BLOCK_SIZE"])) + + envsum_kernel[grid]( + out_env, + sensitivities, + features, + pair_second, + atom_ids, + atom_starts, + n_atom_with_pairs, + n_nu, + n_feat, + dtype=dtype, + ) + return out_env
+ + + +
+[docs] +def envsum(sense, features, pfirst, psecond): + if sense.device == torch.device("cpu"): + return envsum_alternative(sense, features, pfirst, psecond) + psecond_hold = psecond + argsort, atom1_ids, atom1_starts, pfirst, (sense, psecond) = resort_pairs_cached(pfirst, [sense, psecond]) + resort_pairs_cached(psecond_hold, []) # Preemptively sort for backwards pass. + return envsum_triton(sense, features, pfirst, psecond, atom1_ids, atom1_starts, out_env=None)
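In plain PyTorch terms, envsum accumulates, for each pair (i, j), the outer product of the pair's sensitivities with atom j's features into atom i's environment: env[i, nu, f] = sum over pairs p with pfirst[p] == i of sense[p, nu] * features[psecond[p], f]. A rough reference sketch of that contraction (a simplified stand-in for the env_pytorch fallback, not the library code itself):

import torch

def envsum_reference(sense, features, pair_first, pair_second):
    # sense: (n_pairs, n_nu), features: (n_atoms, n_feat)
    n_atoms, n_feat = features.shape
    n_pairs, n_nu = sense.shape
    # Per-pair outer product, shape (n_pairs, n_nu, n_feat).
    pair_env = sense.unsqueeze(2) * features[pair_second].unsqueeze(1)
    env = torch.zeros(n_atoms, n_nu, n_feat, dtype=features.dtype, device=features.device)
    env.index_add_(0, pair_first, pair_env)  # scatter-add onto the first atom of each pair
    return env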
+ + + +@triton.autotune(configs=get_autotune_config(), key=["sens_size", "feat_size"], prune_configs_by={"early_config_prune": config_pruner}) +@triton.jit +def sensesum_kernel( + out_sense_ptr, + env_ptr, + feat_ptr, + pfirst_ptr, + psecond_ptr, + pair_size, + sens_size: tl.constexpr, + feat_size: tl.constexpr, + SENS_BLOCK_SIZE: tl.constexpr, + FEAT_BLOCK_SIZE: tl.constexpr, + dtype: tl.constexpr = tl.float32, +): + pair_id = tl.program_id(axis=0) + sense_id = tl.program_id(axis=1) + num_feat_blocks: tl.constexpr = tl.cdiv(feat_size, FEAT_BLOCK_SIZE) + valid_pair = pair_id < pair_size + + first = tl.load(pfirst_ptr + pair_id, mask=valid_pair, other=0) + second = tl.load(psecond_ptr + pair_id, mask=valid_pair, other=0) + + sens_block_ids = tl.arange(0, SENS_BLOCK_SIZE) + (sense_id * SENS_BLOCK_SIZE) + feat_block_ids = tl.arange(0, FEAT_BLOCK_SIZE) + + valid_sens = sens_block_ids < sens_size + + tmp = tl.zeros((SENS_BLOCK_SIZE,), dtype=dtype) + for feat_id in range(num_feat_blocks): + valid_feat = feat_block_ids < feat_size + env_block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] + valid_env = valid_sens[:, None] & valid_feat[None, :] + # [SENS_BLOCK_SIZE, FEAT_BLOCK_SIZE] + env = tl.load(env_ptr + (first * sens_size * feat_size) + env_block_ids, mask=valid_env, other=0.0) + # [FEAT_BLOCK_SIZE, ] + feat = tl.load(feat_ptr + (second * feat_size) + feat_block_ids, mask=valid_feat, other=0.0) + # TODO: Here we use outer product followed by sum b/c built-in triton dot needs batches and FP<64. + # Can we make this better then? + # For future reference: + """ + type_f32: tl.constexpr = tl.float32 + type_check: tl.constexpr = (dtype == type_f32) + if type_check: + res = tl.dot(env, feat[:, None]) + else: + res = tl.sum(env * feat[None, :], axis=1) + """ + tmp += tl.sum(env * feat[None, :], axis=1) + # increment the feat block id + feat_block_ids += FEAT_BLOCK_SIZE + # TODO: use sparsity of sensitivities to reduce workload? (see numba envsum implementation) + tl.store(out_sense_ptr + (pair_id * sens_size) + sens_block_ids, tmp, mask=valid_sens) + + +
+[docs] +def sensesum(env, features, pair_first, pair_second, out_sense=None): + if env.device == torch.device("cpu"): + return sensesum_alternative(env, features, pair_first, pair_second) + + _, n_nu, _ = env.shape + n_atom, n_feat = features.shape + n_pairs = len(pair_first) + + if out_sense is None: + out_sense = torch.zeros((n_pairs, n_nu), dtype=features.dtype, device=features.device) + + dtype = tl.float32 + if features.dtype == torch.float64: + dtype = tl.float64 + + grid = lambda META: (n_pairs, triton.cdiv(n_nu, META["SENS_BLOCK_SIZE"])) + sensesum_kernel[grid](out_sense, env, features, pair_first, pair_second, n_pairs, n_nu, n_feat, dtype=dtype) + return out_sense
+ + + +@triton.autotune(configs=get_autotune_config(), key=["sens_size", "feat_size"], prune_configs_by={"early_config_prune": config_pruner}) +@triton.jit +def featsum_kernel( + out_feat, + env_ptr, + sens_ptr, + pfirst_ptr, + psecond_ptr, + atom2_ids_ptr, + atom2_starts_ptr, + atom_size, + sens_size: tl.constexpr, + feat_size: tl.constexpr, + SENS_BLOCK_SIZE: tl.constexpr, + FEAT_BLOCK_SIZE: tl.constexpr, + dtype: tl.constexpr = tl.float32, +): + atom_id = tl.program_id(axis=0) + feat_id = tl.program_id(axis=1) + num_sense_blocks: tl.constexpr = tl.cdiv(sens_size, SENS_BLOCK_SIZE) + valid_atom = atom_id < atom_size + + start = tl.load(atom2_starts_ptr + atom_id, mask=valid_atom, other=0) + end = tl.load(atom2_starts_ptr + atom_id + 1, mask=valid_atom, other=0) + target_id = tl.load(atom2_ids_ptr + atom_id, mask=valid_atom, other=0) + + feat_block_ids = tl.arange(0, FEAT_BLOCK_SIZE) + (feat_id * FEAT_BLOCK_SIZE) + + valid_feat = feat_block_ids < feat_size + + tmp = tl.zeros((FEAT_BLOCK_SIZE,), dtype=dtype) + + for ind in range(start, end): + sens_block_ids = tl.arange(0, SENS_BLOCK_SIZE) + for sense_id in range(num_sense_blocks): + valid_sens = sens_block_ids < sens_size + # [SENS_BLOCK_SIZE,], coming from the pair sensitivity + sense = tl.load(sens_ptr + (ind * sens_size) + sens_block_ids, mask=valid_sens, other=0.0) + atom1_ind = tl.load(pfirst_ptr + ind) + # [SENS_BLOCK_SIZE, FEAT_BLOCK_SIZE] + env_block_ids = sens_block_ids[:, None] * feat_size + feat_block_ids[None, :] + valid_env = valid_sens[:, None] & valid_feat[None, :] + env = tl.load(env_ptr + (atom1_ind * sens_size * feat_size) + env_block_ids, mask=valid_env, other=0.0) + # temp_mat and tmp is [FEAT_BLOCK_SIZE,] + temp_mat = tl.sum(env * sense[:, None], axis=0) + tmp = tmp + temp_mat + # increment the sense block id + sens_block_ids += SENS_BLOCK_SIZE + tl.store(out_feat + (target_id * feat_size) + feat_block_ids, tmp, mask=valid_feat) + + +
+[docs] +def featsum_triton(env, sense, pair_first, pair_second, atom2_ids, atom2_starts, out_feat=None): + + n_atom, n_nu, n_feat = env.shape + (n_pairs,) = pair_first.shape + (n_atoms_with_pairs,) = atom2_ids.shape + + if out_feat is None: + out_feat = torch.zeros((n_atom, n_feat), dtype=env.dtype, device=env.device) + + dtype = tl.float32 + if env.dtype == torch.float64: + dtype = tl.float64 + + grid = lambda META: (n_atoms_with_pairs, triton.cdiv(n_feat, META["FEAT_BLOCK_SIZE"])) + + featsum_kernel[grid]( + out_feat, + env, + sense, + pair_first, + pair_second, + atom2_ids, + atom2_starts, + n_atoms_with_pairs, + n_nu, + n_feat, + dtype=dtype, + ) + return out_feat
+ + + +
+[docs] +def featsum(env, sense, pfirst, psecond): + if env.device == torch.device("cpu"): + return featsum_alternative(env, sense, pfirst, psecond) + pfirst_hold = pfirst + argsort, atom2_ids, atom2_starts, psecond, (sense, pfirst) = resort_pairs_cached(psecond, [sense, pfirst]) + resort_pairs_cached(pfirst_hold, []) # preemptively sort (probably no-op) + return featsum_triton(env, sense, pfirst, psecond, atom2_ids, atom2_starts, out_feat=None)
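Because the pure-PyTorch fallbacks implement the same contractions, a quick consistency check of the Triton path is to compare the two backends on random data. This is only a sketch: the sizes are arbitrary, a CUDA device is assumed, signatures are taken from the wrappers on this page, and small float32 differences from summation order are expected. The dedicated test modules (test_env_numba / test_env_triton) do this more systematically.

import torch
from hippynn.custom_kernels import env_pytorch, env_triton

n_atoms, n_pairs, n_nu, n_feat = 40, 300, 20, 64
dev = torch.device("cuda")
sense = torch.rand(n_pairs, n_nu, device=dev)
features = torch.rand(n_atoms, n_feat, device=dev)
pfirst = torch.randint(0, n_atoms, (n_pairs,), device=dev)
psecond = torch.randint(0, n_atoms, (n_pairs,), device=dev)

env_ref = env_pytorch.envsum(sense, features, pfirst, psecond)
env_tri = env_triton.envsum(sense, features, pfirst, psecond)
print(torch.allclose(env_ref, env_tri, atol=1e-5))
# sensesum and featsum can be compared the same way.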
\ No newline at end of file
diff --git a/_modules/hippynn/custom_kernels/tensor_wrapper.html b/_modules/hippynn/custom_kernels/tensor_wrapper.html
index 812f021f..789f7f5e 100644
--- a/_modules/hippynn/custom_kernels/tensor_wrapper.html
+++ b/_modules/hippynn/custom_kernels/tensor_wrapper.html
@@ -124,8 +124,8 @@

Source code for hippynn.custom_kernels.tensor_wrapper

[docs] def __init__(self): if numba.cuda.is_available(): - self.kernel64 = self.make_kernel(numba.float64) - self.kernel32 = self.make_kernel(numba.float32) + self.kernel64 = None + self.kernel32 = None else: self.kernel64 = _numba_gpu_not_found self.kernel32 = _numba_gpu_not_found
@@ -146,8 +146,12 @@

Source code for hippynn.custom_kernels.tensor_wrapper

with numba.cuda.gpus[dev.index]: numba_args = batch_convert_torch_to_numba(*args) if dtype == torch.float64: + if self.kernel64 is None: + self.kernel64 = self.make_kernel(numba.float64) self.kernel64[launch_bounds](*numba_args) elif dtype == torch.float32: + if self.kernel32 is None: + self.kernel32 = self.make_kernel(numba.float32) self.kernel32[launch_bounds](*numba_args) else: raise ValueError("Bad dtype: {}".format(dtype)) diff --git a/_modules/hippynn/custom_kernels/test_env_numba.html b/_modules/hippynn/custom_kernels/test_env_numba.html index fafa2ea7..77369fa9 100644 --- a/_modules/hippynn/custom_kernels/test_env_numba.html +++ b/_modules/hippynn/custom_kernels/test_env_numba.html @@ -204,6 +204,7 @@

Source code for hippynn.custom_kernels.test_env_numba

TEST_LARGE_PARAMS = dict(n_molecules=1000, n_atoms=30, atom_prob=0.7, n_features=80, n_nu=20) TEST_MEGA_PARAMS = dict(n_molecules=500, n_atoms=30, atom_prob=0.7, n_features=128, n_nu=100) TEST_ULTRA_PARAMS = dict(n_molecules=500, n_atoms=30, atom_prob=0.7, n_features=128, n_nu=320) +TEST_GIGA_PARAMS = dict(n_molecules=32, n_atoms=30, atom_prob=0.7, n_features=512, n_nu=320) # reference implementation @@ -566,6 +567,12 @@

Source code for hippynn.custom_kernels.test_env_numba

if use_verylarge_gpu: if use_ultra: + + print("-" * 80) + print("Giga systems:", TEST_GIGA_PARAMS) + tester.check_speed( + n_repetitions=20, data_size=TEST_GIGA_PARAMS, device=torch.device("cuda"), compare_against=compare_against + ) print("-" * 80) print("Ultra systems:", TEST_ULTRA_PARAMS) tester.check_speed( diff --git a/_modules/hippynn/databases/database.html b/_modules/hippynn/databases/database.html index 7fe0528b..d3a8c80e 100644 --- a/_modules/hippynn/databases/database.html +++ b/_modules/hippynn/databases/database.html @@ -79,6 +79,8 @@

Source code for hippynn.databases.database

 """
 Base database functionality from dictionary of numpy arrays
 """
+
+from typing import Union
 import warnings
 import numpy as np
 import torch
@@ -102,17 +104,18 @@ 

Source code for hippynn.databases.database

 [docs]
     def __init__(
         self,
-        arr_dict,
-        inputs,
-        targets,
-        seed,
-        test_size=None,
-        valid_size=None,
-        num_workers=0,
-        pin_memory=True,
-        allow_unfound=False,
-        auto_split=False,
-        device=None,
+        arr_dict: dict[str, torch.Tensor],
+        inputs: list[str],
+        targets: list[str],
+        seed: Union[int, np.random.RandomState, tuple],
+        test_size: Union[float, int] = None,
+        valid_size: Union[float, int] = None,
+        num_workers: int = 0,
+        pin_memory: bool = True,
+        allow_unfound: bool = False,
+        auto_split: bool = False,
+        device: torch.device = None,
+        dataloader_kwargs: dict[str, object] = None,
         quiet=False,
     ):
         """
@@ -129,6 +132,9 @@ 

Source code for hippynn.databases.database

         :param allow_unfound: If true, skip checking if the needed inputs and targets are found.
            This allows setting inputs=None and/or targets=None.
         :param auto_split: If true, look for keys like "split_*" to make initial splits from. See write_npz() method.
+        :param device: if set, move the dataset to this device after splitting.
+        :param dataloader_kwargs: dictionary, passed to pytorch dataloaders in addition to num_workers, pin_memory.
+           Refer to pytorch documentation for details.
         :param quiet: If True, print little or nothing while loading.
         """
 
@@ -203,7 +209,9 @@ 

Source code for hippynn.databases.database

             if not self.splitting_completed:
                 raise ValueError("Device cannot be set in constructor unless automatic split provided.")
             else:
-                self.send_to_device(device)
+ self.send_to_device(device) + + self.dataloader_kwargs = dataloader_kwargs.copy() if dataloader_kwargs else {}
def __len__(self): @@ -534,6 +542,7 @@

Source code for hippynn.databases.database

             shuffle=shuffle,
             pin_memory=self.pin_memory,
             num_workers=self.num_workers,
+            **self.dataloader_kwargs,
         )
 
         return generator
@@ -632,7 +641,7 @@

Source code for hippynn.databases.database

 
 
[docs] - def write_npz(self, file: str, record_split_masks: bool = True, overwrite: bool = False, split_prefix=None, return_only=False): + def write_npz(self, file: str, record_split_masks: bool = True, compressed: bool = True, overwrite: bool = False, split_prefix=None, return_only=False): """ :param file: str, Path, or file object compatible with np.save :param record_split_masks: @@ -679,7 +688,10 @@

Source code for hippynn.databases.database

             if file.exists() and not overwrite:
                 raise FileExistsError(f"File exists: {file}")
 
-        np.savez_compressed(file, **arr_dict)
+        if compressed:
+            np.savez_compressed(file, **arr_dict)
+        else:
+            np.savez(file, **arr_dict)
 
         return arr_dict
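The compressed flag simply selects between the two NumPy writers, trading file size for save speed. A usage sketch (the database object and filename are placeholders):

arrs = db.write_npz("my_dataset.npz", record_split_masks=True, compressed=False, overwrite=True)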
diff --git a/_modules/hippynn/experiment/controllers.html b/_modules/hippynn/experiment/controllers.html index a724cc3e..336bd205 100644 --- a/_modules/hippynn/experiment/controllers.html +++ b/_modules/hippynn/experiment/controllers.html @@ -84,7 +84,6 @@

Source code for hippynn.experiment.controllers

from torch.optim.lr_scheduler import ReduceLROnPlateau -

[docs] class Controller: @@ -133,12 +132,10 @@

Source code for hippynn.experiment.controllers

fraction_train_eval=0.1, quiet=False, ): + super().__init__() self.optimizer = optimizer - self.scheduler = scheduler - self.stopping_key = stopping_key - self.batch_size = batch_size self.eval_batch_size = eval_batch_size or batch_size if max_epochs is None: @@ -170,7 +167,8 @@

Source code for hippynn.experiment.controllers

[docs] def state_dict(self): state_dict = {k: getattr(self, k) for k in self._state_vars} - state_dict["optimizer"] = self.optimizer.state_dict() + if self.optimizer is not None: + state_dict["optimizer"] = self.optimizer.state_dict() state_dict["scheduler"] = [sch.state_dict() for sch in self.scheduler_list] return state_dict

@@ -182,7 +180,8 @@

Source code for hippynn.experiment.controllers

for sch, sdict in zip(self.scheduler_list, state_dict["scheduler"]): sch.load_state_dict(sdict) - self.optimizer.load_state_dict(state_dict["optimizer"]) + if self.optimizer is not None: + self.optimizer.load_state_dict(state_dict["optimizer"]) for k in self._state_vars: setattr(self, k, state_dict[k])

@@ -194,7 +193,7 @@

Source code for hippynn.experiment.controllers

[docs] - def push_epoch(self, epoch, better_model, metric): + def push_epoch(self, epoch, better_model, metric, _print=print): self.current_epoch += 1 if better_model: @@ -209,8 +208,9 @@

Source code for hippynn.experiment.controllers

sch.step() if not self.quiet: - print("Epochs since last best:", self.boredom) - print("Current max epochs:", self.max_epochs) + _print("Epochs since last best:", self.boredom) + _print("Current max epochs:", self.max_epochs) + return self.current_epoch < self.max_epochs

@@ -239,27 +239,31 @@

Source code for hippynn.experiment.controllers

[docs] - def push_epoch(self, epoch, better_model, metric): + def push_epoch(self, epoch, better_model, metric, _print=print): if better_model: if self.boredom > 0 and not self.quiet: - print("Patience for training restored.") + _print("Patience for training restored.") self.boredom = 0 self.last_best = epoch - return super().push_epoch(epoch, better_model, metric)

+ return super().push_epoch(epoch, better_model, metric, _print=_print)
@property def max_epochs(self): - return min(self.last_best + self.patience, self._max_epochs)
+ return min(self.last_best + self.patience + 1, self._max_epochs)
+# Developer note: The inheritance here is only so that pytorch lightning +# readily identifies this as a scheduler.
[docs] -class RaiseBatchSizeOnPlateau: +class RaiseBatchSizeOnPlateau(ReduceLROnPlateau): """ Learning rate scheduler compatible with pytorch schedulers. + Note: The "VERBOSE" Parameter has been deprecated and no longer does anything. + This roughly implements the scheme outlined in the following paper: .. code-block:: none @@ -288,9 +292,20 @@

Source code for hippynn.experiment.controllers

patience=10, threshold=0.0001, threshold_mode="rel", - verbose=True, + verbose=None, # DEPRECATED controller=None, ): + """ + + :param optimizer: + :param max_batch_size: + :param factor: + :param patience: + :param threshold: + :param threshold_mode: + :param verbose: + :param controller: + """ if threshold_mode not in ("abs", "rel"): raise ValueError("Mode must be 'abs' or 'rel'") @@ -301,14 +316,18 @@

Source code for hippynn.experiment.controllers

factor=factor, threshold=threshold, threshold_mode=threshold_mode, - verbose=verbose, ) self.controller = controller self.max_batch_size = max_batch_size self.best_metric = float("inf") self.boredom = 0 - self.last_epoch = 0

+ self.last_epoch = 0 + warnings.warn("Parameter verbose no longer supported for schedulers. It will be ignored.")
+ + @property + def optimizer(self): + return self.inner.optimizer
[docs] @@ -368,12 +387,9 @@

Source code for hippynn.experiment.controllers

new_batch_size = min(new_batch_size, self.max_batch_size) self.controller.batch_size = new_batch_size self.boredom = 0 - if self.inner.verbose: - print("Raising batch size to", new_batch_size) + if new_batch_size >= self.max_batch_size: self.inner.last_epoch = self.last_epoch - 1 - if self.inner.verbose: - print("Max batch size reached, Lowering learning rate from here.") return
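A usage sketch of the scheduler on its own; the values are illustrative, and model and controller are assumed to already exist, controller being the hippynn Controller whose batch_size gets raised:

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = RaiseBatchSizeOnPlateau(
    optimizer,
    max_batch_size=512,   # stop raising the batch size here; afterwards the inner
                          # ReduceLROnPlateau lowers the learning rate instead
    factor=0.5,           # forwarded to the wrapped ReduceLROnPlateau
    patience=10,
    controller=controller,
)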

diff --git a/_modules/hippynn/experiment/lightning_trainer.html b/_modules/hippynn/experiment/lightning_trainer.html
new file mode 100644
index 00000000..ad7282ee
--- /dev/null
+++ b/_modules/hippynn/experiment/lightning_trainer.html
@@ -0,0 +1,543 @@

Source code for hippynn.experiment.lightning_trainer

+"""
+Pytorch Lightning training interface.
+
+This module is somewhat experimental. Using pytorch lightning
+successfully in a distributed context may require understanding
+and adjusting the various settings related to parallelism, e.g.
+multiprocessing context, torch ddp backend, and how they interact
+with your HPC environment.
+
+Some features of hippynn experiments may not be implemented yet.
+    - The plotmaker is currently not supported.
+
+"""
+import warnings
+import copy
+from pathlib import Path
+
+import torch
+
+import pytorch_lightning as pl
+
+from .routines import TrainingModules
+from ..databases import Database
+from .routines import SetupParams, setup_training
+from ..graphs import GraphModule
+from .controllers import Controller
+from .metric_tracker import MetricTracker
+from .step_functions import get_step_function, StandardStep
+from ..tools import print_lr
+from . import serialization
+
+
+
+[docs] +class HippynnLightningModule(pl.LightningModule): +
+[docs] + def __init__( + self, + model: GraphModule, + loss: GraphModule, + eval_loss: GraphModule, + eval_names: list[str], + stopping_key: str, + optimizer_list: list[torch.optim.Optimizer], + scheduler_list: list[torch.optim.lr_scheduler], + controller: Controller, + metric_tracker: MetricTracker, + inputs: list[str], + targets: list[str], + n_outputs: int, + *args, + **kwargs, + ): # forwards args and kwargs to where? + super().__init__() + + self.save_hyperparameters(ignore=["loss", "model", "eval_loss", "controller", "optimizer_list", "scheduler_list"]) + + self.model = model + self.loss = loss + self.eval_loss = eval_loss + self.eval_names = eval_names + self.stopping_key = stopping_key + self.controller = controller + self.metric_tracker = metric_tracker + self.optimizer_list = optimizer_list + self.scheduler_list = scheduler_list + self.inputs = inputs + self.targets = targets + self.n_inputs = len(self.inputs) + self.n_targets = len(self.targets) + self.n_outputs = n_outputs + + self.structure_file = None + + self._last_reload_dlene = None # storage for whether batch size should be changed. + + # Storage for predictions across batches for eval mode. + self.eval_step_outputs = [] + self.controller.optimizer = None + + for optimizer in self.optimizer_list: + if not isinstance(step_fn := get_step_function(optimizer), StandardStep): # := + raise NotImplementedError(f"Optimzers with non-standard steps are not yet supported. {optimizer,step_fn}") + + if args or kwargs: + raise NotImplementedError("Generic args and kwargs not supported.")
+ + +
+[docs] + @classmethod + def from_experiment_setup(cls, training_modules: TrainingModules, database: Database, setup_params: SetupParams, **kwargs): + training_modules, controller, metric_tracker = setup_training(training_modules, setup_params) + return cls.from_train_setup(training_modules, database, controller, metric_tracker, **kwargs)
+ + +
+[docs] + @classmethod + def from_train_setup( + cls, + training_modules: TrainingModules, + database: Database, + controller: Controller, + metric_tracker: MetricTracker, + callbacks=None, + batch_callbacks=None, + **kwargs, + ): + + model, loss, evaluator = training_modules + + warnings.warn("PytorchLightning hippynn trainer is still experimental.") + + if evaluator.plot_maker is not None: + warnings.warn("plot_maker is not currently supported in pytorch lightning. The current plot_maker will be ignored.") + + trainer = cls( + model=model, + loss=loss, + eval_loss=evaluator.loss, + eval_names=evaluator.loss_names, + optimizer_list=[controller.optimizer], + scheduler_list=controller.scheduler_list, + stopping_key=controller.stopping_key, + controller=controller, + metric_tracker=metric_tracker, + inputs=database.inputs, + targets=database.targets, + n_outputs=evaluator.n_outputs, + **kwargs, + ) + + # pytorch lightning is now in charge of stepping the scheduler. + controller.scheduler_list = [] + + if callbacks is not None or batch_callbacks is not None: + return NotImplemented("arbitrary callbacks are not yet supported with pytorch lightning.") + + return trainer, HippynnDataModule(database, controller.batch_size)
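Putting the pieces together, a minimal training sketch might look like the following; training_modules, database, and setup_params are assumed to be built exactly as for a standard hippynn experiment, and the pl.Trainer options are plain pytorch-lightning arguments whose appropriate values depend on your hardware and parallel setup:

import pytorch_lightning as pl
from hippynn.experiment.lightning_trainer import HippynnLightningModule

lightning_mod, datamodule = HippynnLightningModule.from_experiment_setup(
    training_modules, database, setup_params,
)
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=-1,   # let the hippynn controller decide when to stop (via trainer.should_stop)
)
trainer.fit(lightning_mod, datamodule=datamodule)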
+ + +
+[docs] + def on_save_checkpoint(self, checkpoint) -> None: + + # Note to future developers: + # trainer.log_dir property needs to be called on all ranks! This is weird but important; + # do not move trainer.log_dir inside of a rank zero operation! + # see https://github.com/Lightning-AI/pytorch-lightning/discussions/8321 + # Thank you to https://github.com/semaphore-egg . + log_dir = self.trainer.log_dir + + if not self.structure_file: + # Perform change on all ranks. + sf = serialization.DEFAULT_STRUCTURE_FNAME + self.structure_file = sf + + if self.global_rank == 0 and not self.structure_file: + self.print("creating structure file.") + structure = dict( + model=self.model, + loss=self.loss, + eval_loss=self.eval_loss, + controller=self.controller, + optimizer_list=self.optimizer_list, + scheduler_list=self.scheduler_list, + ) + path: Path = Path(log_dir).joinpath(sf) + self.print("Saving structure file at", path) + torch.save(obj=structure, f=path) + + checkpoint["controller_state"] = self.controller.state_dict() + return
+ + +
+[docs] + @classmethod + def load_from_checkpoint(cls, checkpoint_path, map_location=None, structure_file=None, hparams_file=None, strict=True, **kwargs): + + if structure_file is None: + # Assume checkpoint_path is like <model_name>/version_<n>/checkpoints/<something>.chkpt + # and that experiment file is stored at <model_name>/version_<n>/experiment_structure.pt + structure_file = Path(checkpoint_path) + structure_file = structure_file.parent.parent + structure_file = structure_file.joinpath(serialization.DEFAULT_STRUCTURE_FNAME) + + structure_args = torch.load(structure_file) + + return super().load_from_checkpoint( + checkpoint_path, map_location=map_location, hparams_file=hparams_file, strict=strict, **structure_args, **kwargs + )
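With the default lightning log layout the checkpoint path alone is enough, since the structure file is looked up two directories above the checkpoint; otherwise it can be passed explicitly. The paths below are illustrative:

module = HippynnLightningModule.load_from_checkpoint(
    "my_model/version_0/checkpoints/epoch=10-step=1100.ckpt"
)

# If the structure file was moved, point at it directly:
module = HippynnLightningModule.load_from_checkpoint(
    "some/other/place/last.ckpt",
    structure_file="my_model/version_0/experiment_structure.pt",
)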
+ + +
+[docs] + def on_load_checkpoint(self, checkpoint) -> None: + cstate = checkpoint.pop("controller_state") + self.controller.load_state_dict(cstate) + return
+ + +
+[docs] + def configure_optimizers(self): + + scheduler_list = [] + for s in self.scheduler_list: + config = { + "scheduler": s, + "interval": "epoch", # can be epoch or step + "frequency": 1, # How many intervals should pass between calls to `scheduler.step()`. + "monitor": "valid_" + self.stopping_key, # Metric to monitor for schedulers like `ReduceLROnPlateau` + "strict": True, + "name": "learning_rate", + } + scheduler_list.append(config) + + optimizer_list = self.optimizer_list.copy() + + return optimizer_list, scheduler_list
+ + +
+[docs] + def on_train_epoch_start(self): + for optimizer in self.optimizer_list: + print_lr(optimizer, print_=self.print) + self.print("Batch size:", self.trainer.train_dataloader.batch_size)
+ + +
+[docs] + def training_step(self, batch, batch_idx): + + batch_inputs = batch[: self.n_inputs] + batch_targets = batch[-self.n_targets :] + + batch_model_outputs = self.model(*batch_inputs) + batch_train_loss = self.loss(*batch_model_outputs, *batch_targets)[0] + + self.log("train_loss", batch_train_loss) + return batch_train_loss
+ + + def _eval_step(self, batch, batch_idx): + + batch_inputs = batch[: self.n_inputs] + batch_targets = batch[-self.n_targets :] + + # It is very, very common to fit to derivatives, e.g. force, in hippynn. Override lightning default. + with torch.autograd.set_grad_enabled(True): + batch_predictions = self.model(*batch_inputs) + + batch_predictions = [bp.detach() for bp in batch_predictions] + + outputs = (batch_predictions, batch_targets) + self.eval_step_outputs.append(outputs) + return batch_predictions + +
+[docs] + def validation_step(self, batch, batch_idx): + return self._eval_step(batch, batch_idx)
+ + +
+[docs] + def test_step(self, batch, batch_idx): + return self._eval_step(batch, batch_idx)
+ + + def _eval_epoch_end(self, prefix): + + all_batch_predictions, all_batch_targets = zip(*self.eval_step_outputs) + # now 'shape' (n_batch, n_outputs) -> need to transpose. + all_batch_predictions = [[bpred[i] for bpred in all_batch_predictions] for i in range(self.n_outputs)] + # now 'shape' (n_batch, n_targets) -> need to transpose. + all_batch_targets = [[bpred[i] for bpred in all_batch_targets] for i in range(self.n_targets)] + + # now cat each prediction and target across the batch index. + all_predictions = [torch.cat(x, dim=0) if x[0].shape != () else x[0] for x in all_batch_predictions] + all_targets = [torch.cat(x, dim=0) for x in all_batch_targets] + + all_losses = [x.item() for x in self.eval_loss(*all_predictions, *all_targets)] + self.eval_step_outputs.clear() # free memory + + loss_dict = {name: value for name, value in zip(self.eval_names, all_losses)} + + self.log_dict({prefix + k: v for k, v in loss_dict.items()}, sync_dist=True) + + return + +
+[docs] + def on_validation_epoch_end(self): + self._eval_epoch_end(prefix="valid_") + return
+ + +
+[docs] + def on_test_epoch_end(self): + self._eval_epoch_end(prefix="test_") + return
+ + + def _eval_end(self, prefix, when=None) -> None: + if when is None: + if self.trainer.sanity_checking: + when = "Sanity Check" + else: + when = self.current_epoch + + # Step 1: get metrics reduced from all ranks. + # Copied pattern from pytorch_lightning. + metrics = copy.deepcopy(self.trainer.callback_metrics) + + pre_len = len(prefix) + loss_dict = {k[pre_len:]: v.item() for k, v in metrics.items() if k.startswith(prefix)} + + loss_dict = {prefix[:-1]: loss_dict} # strip underscore from prefix and wrap. + + if self.trainer.sanity_checking: + self.print("Sanity check metric values:") + self.metric_tracker.evaluation_print(loss_dict, _print=self.print) + return + + # Step 2: register metrics + out_ = self.metric_tracker.register_metrics(loss_dict, when=when) + better_metrics, better_model, stopping_metric = out_ + self.metric_tracker.evaluation_print_better(loss_dict, better_metrics, _print=self.print) + + continue_training = self.controller.push_epoch(self.current_epoch, better_model, stopping_metric, _print=self.print) + + if not continue_training: + self.print("Controller is terminating training.") + self.trainer.should_stop = True + + # Step 3: Logic for changing the batch size without always requiring new dataloaders. + # Step 3a: don't do this when not testing. + if not self.trainer.training: + return + + controller_batch_size = self.controller.batch_size + trainer_batch_size = self.trainer.train_dataloader.batch_size + if controller_batch_size != trainer_batch_size: + # Need to trigger a batch size change. + if self._last_reload_dlene is None: + # save the original value of this variable to the pl module + self._last_reload_dlene = self.trainer.reload_dataloaders_every_n_epochs + + # TODO: Make this run even if there isn't an explicit datamodule? + self.trainer.datamodule.batch_size = controller_batch_size + # Tell PL lightning to reload the dataloaders now. + self.trainer.reload_dataloaders_every_n_epochs = 1 + + elif self._last_reload_dlene is not None: + # Restore the last saved value from the pl module. + self.trainer.reload_dataloaders_every_n_epochs = self._last_reload_dlene + self._last_reload_dlene = None + else: + # Batch sizes match, and there's no variable to restore. + pass + return + +
+[docs] + def on_validation_end(self): + self._eval_end(prefix="valid_") + return
+ + +
+[docs] + def on_test_end(self): + self._eval_end(prefix="test_", when="test") + return
+
+ + + +
+[docs] +class LightingPrintStagesCallback(pl.Callback): + """ + This callback is for debugging only. + It prints whenever a callback stage is entered in pytorch lightning. + """ + + for k in dir(pl.Callback): + if k.startswith("on_"): + + def some_method(self, *args, _k=k, **kwargs): + all_args = kwargs.copy() + all_args.update({i: a for i, a in enumerate(args)}) + int_args = {k: v for k, v in all_args.items() if isinstance(v, int)} + print("Callback stage:", _k, "with integer arguments:", int_args) + + exec(f"{k} = some_method") + del some_method
+ + + +
+[docs] +class HippynnDataModule(pl.LightningDataModule): +
+[docs] + def __init__(self, database: Database, batch_size): + super().__init__() + self.database = database + self.batch_size = batch_size
+ + +
+[docs] + def train_dataloader(self): + return self.database.make_generator("train", "train", self.batch_size)
+ + +
+[docs] + def val_dataloader(self): + return self.database.make_generator("valid", "eval", self.batch_size)
+ + +
+[docs] + def test_dataloader(self): + return self.database.make_generator("test", "eval", self.batch_size)
+
\ No newline at end of file
diff --git a/_modules/hippynn/experiment/metric_tracker.html b/_modules/hippynn/experiment/metric_tracker.html
index 3b738653..2ec760a2 100644
--- a/_modules/hippynn/experiment/metric_tracker.html
+++ b/_modules/hippynn/experiment/metric_tracker.html
@@ -173,7 +173,6 @@

Source code for hippynn.experiment.metric_tracker

except KeyError: if split_type not in self.best_metric_values: # Haven't seen this split before! - print("ADDING ",split_type) self.best_metric_values[split_type] = {} better_metrics[split_type] = {} better = True # old best was not found! @@ -187,7 +186,7 @@

Source code for hippynn.experiment.metric_tracker

else: self.other_metric_values[when] = metric_info - if self.stopping_key: + if self.stopping_key and "valid" in metric_info: better_model = better_metrics.get("valid", {}).get(self.stopping_key, False) stopping_key_metric = metric_info["valid"][self.stopping_key] else: @@ -199,24 +198,24 @@

Source code for hippynn.experiment.metric_tracker

[docs] - def evaluation_print(self, evaluation_dict, quiet=None): + def evaluation_print(self, evaluation_dict, quiet=None, _print=print): if quiet is None: quiet = self.quiet if quiet: return - table_evaluation_print(evaluation_dict, self.metric_names, self.name_column_width)
+ table_evaluation_print(evaluation_dict, self.metric_names, self.name_column_width, _print=_print)
[docs] - def evaluation_print_better(self, evaluation_dict, better_dict, quiet=None): + def evaluation_print_better(self, evaluation_dict, better_dict, quiet=None, _print=print): if quiet is None: quiet = self.quiet if quiet: return - table_evaluation_print_better(evaluation_dict, better_dict, self.metric_names, self.name_column_width) + table_evaluation_print_better(evaluation_dict, better_dict, self.metric_names, self.name_column_width, _print=print) if self.stopping_key: - print( + _print( "Best {} so far: {:>8.5g}".format( self.stopping_key, self.best_metric_values["valid"][self.stopping_key] ) @@ -235,7 +234,7 @@

Source code for hippynn.experiment.metric_tracker

# Decoupled from the estate in case we want to more easily change print formatting.
[docs] -def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, n_columns): +def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, n_columns, _print=print): """ Print metric results as a table, add a '*' character for metrics in better_dict. @@ -258,11 +257,11 @@

Source code for hippynn.experiment.metric_tracker

header = " " * (n_columns + 2) + "".join("{:>14}".format(tn) for tn in type_names) rowstring = "{:<" + str(n_columns) + "}: " + " {}{:>10.5g}" * n_types - print(header) - print("-" * len(header)) + _print(header) + _print("-" * len(header)) for n, valsbet in zip(metric_names, transposed_values_better): rowoutput = [k for bv in valsbet for k in bv] - print(rowstring.format(n, *rowoutput))
+ _print(rowstring.format(n, *rowoutput))
@@ -270,7 +269,7 @@

Source code for hippynn.experiment.metric_tracker

# Decoupled from the estate in case we want to more easily change print formatting.
[docs] -def table_evaluation_print(evaluation_dict, metric_names, n_columns): +def table_evaluation_print(evaluation_dict, metric_names, n_columns, _print=print): """ Print metric results as a table. @@ -288,11 +287,11 @@

Source code for hippynn.experiment.metric_tracker

header = " " * (n_columns + 2) + "".join("{:>14}".format(tn) for tn in type_names) rowstring = "{:<" + str(n_columns) + "}: " + " {:>10.5g}" * n_types - print(header) - print("-" * len(header)) + _print(header) + _print("-" * len(header)) for n, vals in zip(metric_names, transposed_values): - print(rowstring.format(n, *vals)) - print("-" * len(header))
+ _print(rowstring.format(n, *vals)) + _print("-" * len(header))
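The _print hooks added above make it easy to route the metric tables somewhere other than stdout, for example a logging handler. A sketch, assuming the split-to-metrics dictionary layout that MetricTracker passes to these helpers:

import logging

logger = logging.getLogger("hippynn_metrics")
evaluation_dict = {
    "train": {"Loss": 0.021, "RMSE": 0.14},
    "valid": {"Loss": 0.034, "RMSE": 0.18},
}
table_evaluation_print(evaluation_dict, metric_names=["Loss", "RMSE"], n_columns=12, _print=logger.info)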
diff --git a/_modules/hippynn/experiment/routines.html b/_modules/hippynn/experiment/routines.html index 8787754e..457fc1e7 100644 --- a/_modules/hippynn/experiment/routines.html +++ b/_modules/hippynn/experiment/routines.html @@ -395,9 +395,7 @@

Source code for hippynn.experiment.routines

         print("Finishing up...")
     print("Training phase ended.")
 
-    if store_metrics:
-        with open("training_metrics.pkl", "wb") as pfile:
-            pickle.dump(metric_tracker, pfile)
+    torch.save(metric_tracker, "training_metrics.pt")
 
     best_model = metric_tracker.best_model
     if best_model:
@@ -543,6 +541,7 @@ 

Source code for hippynn.experiment.routines

         qprint("_" * 50)
         qprint("Epoch {}:".format(epoch))
         tools.print_lr(optimizer)
+        qprint("Batch Size:", controller.batch_size)
 
         qprint(flush=True, end="")
 
diff --git a/_modules/hippynn/experiment/serialization.html b/_modules/hippynn/experiment/serialization.html
index e428f4db..3d523e12 100644
--- a/_modules/hippynn/experiment/serialization.html
+++ b/_modules/hippynn/experiment/serialization.html
@@ -77,7 +77,9 @@
              
   

Source code for hippynn.experiment.serialization

 """
-checkpoint and state generation
+Checkpoint and state generation.
+
+As a user, in most cases you will only need the `load` functions here.
 """
 
 from typing import Tuple, Union
@@ -90,7 +92,7 @@ 

Source code for hippynn.experiment.serialization

from ..graphs import GraphModule from ..tools import device_fallback from .assembly import TrainingModules -from .controllers import PatienceController +from .controllers import Controller from .device import set_devices from .metric_tracker import MetricTracker @@ -101,13 +103,13 @@

Source code for hippynn.experiment.serialization

[docs] def create_state( model: GraphModule, - controller: PatienceController, + controller: Controller, metric_tracker: MetricTracker, ) -> dict: """Create an experiment state dictionary. :param model: current model - :param controller: patience controller + :param controller: controller :param metric_tracker: current metrics :return: dictionary containing experiment state. :rtype: dict @@ -126,7 +128,7 @@

Source code for hippynn.experiment.serialization

def create_structure_file( training_modules: TrainingModules, database: Database, - controller: PatienceController, + controller: Controller, fname=DEFAULT_STRUCTURE_FNAME, ) -> None: """ @@ -134,7 +136,7 @@

Source code for hippynn.experiment.serialization

:param training_modules: contains model, controller, and loss :param database: database for training - :param controller: patience controller + :param controller: controller :param fname: filename to save the checkpoint :return: None diff --git a/_modules/hippynn/graphs/gops.html b/_modules/hippynn/graphs/gops.html index a8aedce0..2b28fe68 100644 --- a/_modules/hippynn/graphs/gops.html +++ b/_modules/hippynn/graphs/gops.html @@ -133,7 +133,8 @@

Source code for hippynn.graphs.gops

 
     evaluation_inputs_list = []
     evaluation_outputs_list = []
-    unsatisfied_nodes = all_nodes.copy()
+    # need to sort to get stable results between runs/processes.
+    unsatisfied_nodes = list(sorted(all_nodes, key=lambda node: node.name))
     satisfied_nodes = set()
     n = -1
     while len(unsatisfied_nodes) > 0:
diff --git a/_modules/hippynn/interfaces/ase_interface/ase_database.html b/_modules/hippynn/interfaces/ase_interface/ase_database.html
index 5355c7a2..a77d015a 100644
--- a/_modules/hippynn/interfaces/ase_interface/ase_database.html
+++ b/_modules/hippynn/interfaces/ase_interface/ase_database.html
@@ -102,14 +102,14 @@ 

Source code for hippynn.interfaces.ase_interface.ase_database

import os import numpy as np -from ase.io import read +from ase.io import read, iread -from ...tools import np_of_torchdefaultdtype +from ...tools import np_of_torchdefaultdtype, progress_bar from ...databases.database import Database from ...databases.restarter import Restartable from typing import Union from typing import List - +import hippynn.tools
[docs] @@ -169,11 +169,11 @@

Source code for hippynn.interfaces.ase_interface.ase_database

var_list = inputs + targets try: if isinstance(filename, str): - db = read(directory + filename, index=":") + db = list(progress_bar(iread(directory+filename,index=":"), desc='configs'))#read(directory + filename, index=":") elif isinstance(filename, (list, np.ndarray)): db = [] - for name in filename: - temp_db = read(directory + name, index=":") + for name in progress_bar(filename, desc='files'): + temp_db = list(progress_bar(iread(directory + name, index=":"), desc='configs')) db += temp_db except FileNotFoundError as fee: raise FileNotFoundError( diff --git a/_modules/hippynn/layers/hiplayers.html b/_modules/hippynn/layers/hiplayers.html index 5167421c..6f8adc85 100644 --- a/_modules/hippynn/layers/hiplayers.html +++ b/_modules/hippynn/layers/hiplayers.html @@ -426,16 +426,26 @@

Source code for hippynn.layers.hiplayers

         n_atoms_real = in_features.shape[0]
         sense_vals = self.sensitivity(dist_pairs)
 
+        # Sensitivity stacking
+        sense_vec = sense_vals.unsqueeze(1) * (coord_pairs / dist_pairs.unsqueeze(1)).unsqueeze(2)
+        sense_vec = sense_vec.reshape(-1, self.n_dist * 3)
+        sense_stacked = torch.concatenate([sense_vals, sense_vec], dim=1)
+
+        # Message passing, stack sensitivities to coalesce custom kernel call.
+        # shape (n_atoms, n_nu + 3*n_nu, n_feat)
+        env_features_stacked = custom_kernels.envsum(sense_stacked, in_features, pair_first, pair_second)
+        # shape (n_atoms, 4, n_nu, n_feat)
+        env_features_stacked = env_features_stacked.reshape(-1, 4, self.n_dist, self.nf_in)
+
+        # separate to tensor components
+        env_features, env_features_vec = torch.split(env_features_stacked, [1, 3], dim=1)
+
         # Scalar part
-        env_features = custom_kernels.envsum(sense_vals, in_features, pair_first, pair_second)
         env_features = torch.reshape(env_features, (n_atoms_real, self.n_dist * self.nf_in))
         weights_rs = torch.reshape(self.int_weights.permute(0, 2, 1), (self.n_dist * self.nf_in, self.nf_out))
         features_out = torch.mm(env_features, weights_rs)
 
         # Vector part
-        sense_vec = sense_vals.unsqueeze(1) * (coord_pairs / dist_pairs.unsqueeze(1)).unsqueeze(2)
-        sense_vec = sense_vec.reshape(-1, self.n_dist * 3)
-        env_features_vec = custom_kernels.envsum(sense_vec, in_features, pair_first, pair_second)
         env_features_vec = env_features_vec.reshape(n_atoms_real * 3, self.n_dist * self.nf_in)
         features_out_vec = torch.mm(env_features_vec, weights_rs)
         features_out_vec = features_out_vec.reshape(n_atoms_real, 3, self.nf_out)
@@ -475,19 +485,41 @@ 

Source code for hippynn.layers.hiplayers

         n_atoms_real = in_features.shape[0]
         sense_vals = self.sensitivity(dist_pairs)
 
-        # Scalar part
-        env_features = custom_kernels.envsum(sense_vals, in_features, pair_first, pair_second)
+        ####
+        # Sensitivity calculations
+        # scalar: sense_vals
+        # vector: sense_vec
+        # quadrupole: sense_quad
+        rhats = coord_pairs / dist_pairs.unsqueeze(1)
+        sense_vec = sense_vals.unsqueeze(1) * rhats.unsqueeze(2)
+        sense_vec = sense_vec.reshape(-1, self.n_dist * 3)
+        rhatsquad = rhats.unsqueeze(1) * rhats.unsqueeze(2)
+        rhatsquad = (rhatsquad + rhatsquad.transpose(1, 2)) / 2
+        tr = torch.diagonal(rhatsquad, dim1=1, dim2=2).sum(dim=1) / 3.0  # Add divide by 3 early to save flops
+        tr = tr.unsqueeze(1).unsqueeze(2) * torch.eye(3, dtype=tr.dtype, device=tr.device).unsqueeze(0)
+        rhatsquad = rhatsquad - tr
+        rhatsqflat = rhatsquad.reshape(-1, 9)[:, self.upper_ind]  # Upper-diagonal part
+        sense_quad = sense_vals.unsqueeze(1) * rhatsqflat.unsqueeze(2)
+        sense_quad = sense_quad.reshape(-1, self.n_dist * 5)
+        sense_stacked = torch.concatenate([sense_vals, sense_vec, sense_quad], dim=1)
+
+        # Message passing, stack sensitivities to coalesce custom kernel call.
+        # shape (n_atoms, n_nu + 3*n_nu + 5*n_nu, n_feat)
+        env_features_stacked = custom_kernels.envsum(sense_stacked, in_features, pair_first, pair_second)
+        # shape (n_atoms, 9, n_nu, n_feat)
+        env_features_stacked = env_features_stacked.reshape(-1, 9, self.n_dist, self.nf_in)
+
+        # separate to tensor components
+        env_features, env_features_vec, env_features_quad = torch.split(env_features_stacked, [1, 3, 5], dim=1)
+
+        # Scalar stuff.
         env_features = torch.reshape(env_features, (n_atoms_real, self.n_dist * self.nf_in))
         weights_rs = torch.reshape(self.int_weights.permute(0, 2, 1), (self.n_dist * self.nf_in, self.nf_out))
         features_out = torch.mm(env_features, weights_rs)
 
         # Vector part
         # Sensitivity
-        rhats = coord_pairs / dist_pairs.unsqueeze(1)
-        sense_vec = sense_vals.unsqueeze(1) * rhats.unsqueeze(2)
-        sense_vec = sense_vec.reshape(-1, self.n_dist * 3)
         # Weights
-        env_features_vec = custom_kernels.envsum(sense_vec, in_features, pair_first, pair_second)
         env_features_vec = env_features_vec.reshape(n_atoms_real * 3, self.n_dist * self.nf_in)
         features_out_vec = torch.mm(env_features_vec, weights_rs)
         # Norm and scale
@@ -498,16 +530,7 @@ 

Source code for hippynn.layers.hiplayers

 
         # Quadrupole part
         # Sensitivity
-        rhatsquad = rhats.unsqueeze(1) * rhats.unsqueeze(2)
-        rhatsquad = (rhatsquad + rhatsquad.transpose(1, 2)) / 2
-        tr = torch.diagonal(rhatsquad, dim1=1, dim2=2).sum(dim=1) / 3.0  # Add divide by 3 early to save flops
-        tr = tr.unsqueeze(1).unsqueeze(2) * torch.eye(3, dtype=tr.dtype, device=tr.device).unsqueeze(0)
-        rhatsquad = rhatsquad - tr
-        rhatsqflat = rhatsquad.reshape(-1, 9)[:, self.upper_ind]  # Upper-diagonal part
-        sense_quad = sense_vals.unsqueeze(1) * rhatsqflat.unsqueeze(2)
-        sense_quad = sense_quad.reshape(-1, self.n_dist * 5)
         # Weights
-        env_features_quad = custom_kernels.envsum(sense_quad, in_features, pair_first, pair_second)
         env_features_quad = env_features_quad.reshape(n_atoms_real * 5, self.n_dist * self.nf_in)
         features_out_quad = torch.mm(env_features_quad, weights_rs)  ##sum v b
         features_out_quad = features_out_quad.reshape(n_atoms_real, 5, self.nf_out)
@@ -519,6 +542,7 @@ 

Source code for hippynn.layers.hiplayers

         # Scales
         features_out_quad = features_out_quad * self.quadscales.unsqueeze(0)
 
+        # Combine
         features_out_selfpart = self.selfint(in_features)
 
         features_out_total = features_out + features_out_vec + features_out_quad + features_out_selfpart
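The bookkeeping behind the stacked envsum call can be seen in isolation: concatenating scalar, vector, and quadrupole sensitivities along the sensitivity axis lets one kernel launch produce an environment tensor that is split straight back into its tensor components. A shape-only sketch with made-up sizes, independent of the layer code above:

import torch

n_pairs, n_nu, n_atoms, n_feat = 100, 20, 30, 8
sense_vals = torch.rand(n_pairs, n_nu)        # scalar sensitivities
sense_vec  = torch.rand(n_pairs, 3 * n_nu)    # vector sensitivities, flattened
sense_quad = torch.rand(n_pairs, 5 * n_nu)    # traceless quadrupole sensitivities, flattened

sense_stacked = torch.cat([sense_vals, sense_vec, sense_quad], dim=1)   # (n_pairs, 9*n_nu)

# Stand-in for the single envsum output, shape (n_atoms, 9*n_nu, n_feat).
env_stacked = torch.rand(n_atoms, 9 * n_nu, n_feat)
env_stacked = env_stacked.reshape(n_atoms, 9, n_nu, n_feat)
env_s, env_v, env_q = torch.split(env_stacked, [1, 3, 5], dim=1)
print(env_s.shape, env_v.shape, env_q.shape)
# torch.Size([30, 1, 20, 8]) torch.Size([30, 3, 20, 8]) torch.Size([30, 5, 20, 8])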
diff --git a/_modules/hippynn/pretraining.html b/_modules/hippynn/pretraining.html
index 97d8e47c..c8a5d9b4 100644
--- a/_modules/hippynn/pretraining.html
+++ b/_modules/hippynn/pretraining.html
@@ -150,7 +150,7 @@ 

Source code for hippynn.pretraining

         if not eo_layer.weight.data.shape[-1] == eovals.shape[-1]:
             raise ValueError("The shape of the computed E0 values does not match the shape expected by the model.")
         
-        eo_layer.weight.data = eovals.reshape(1,-1)
+        eo_layer.weight.data = eovals.reshape(1, -1)
         print("Computed E0 energies:", eovals)
         eo_layer.weight.data = eovals.expand_as(eo_layer.weight.data)
         eo_layer.weight.requires_grad_(trainable_after)
diff --git a/_modules/hippynn/tools.html b/_modules/hippynn/tools.html
index 76315e82..8b253050 100644
--- a/_modules/hippynn/tools.html
+++ b/_modules/hippynn/tools.html
@@ -243,9 +243,9 @@ 

Source code for hippynn.tools

 
 
+        print_("Learning rate:{:>10.5g}".format(param_group["lr"]))
@@ -343,6 +343,24 @@

Source code for hippynn.tools

 
 
 
+
+[docs] +def recursive_param_count(state_dict, n=0): + for k, v in state_dict.items(): + if isinstance(v, torch.Tensor): + n += v.numel() + elif isinstance(v, dict): + n += recursive_param_count(v) + elif isinstance(v, (list, tuple)): + n += recursive_param_count({i: x for i, x in enumerate(v)}) + elif isinstance(v, (float, int)): + n += 1 + elif v is None: + pass + else: + raise TypeError(f'Unknown type {type(v)=}, value={v}') + return n
+
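Applied to a state dict, this gives a quick total parameter count; for example (the toy network is a placeholder):

import torch

net = torch.nn.Sequential(torch.nn.Linear(10, 32), torch.nn.Linear(32, 1))
print(recursive_param_count(net.state_dict()))   # 10*32 + 32 + 32*1 + 1 = 385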
diff --git a/_modules/index.html b/_modules/index.html index 71e726c0..0e5884d8 100644 --- a/_modules/index.html +++ b/_modules/index.html @@ -80,6 +80,7 @@

All modules for which code is available

  • hippynn.custom_kernels.env_cupy
  • hippynn.custom_kernels.env_numba
  • hippynn.custom_kernels.env_pytorch
  • +
  • hippynn.custom_kernels.env_triton
  • hippynn.custom_kernels.tensor_wrapper
  • hippynn.custom_kernels.test_env_numba
  • hippynn.custom_kernels.utils
  • @@ -91,6 +92,7 @@

    All modules for which code is available

  • hippynn.experiment.controllers
  • hippynn.experiment.device
  • hippynn.experiment.evaluator
  • +
  • hippynn.experiment.lightning_trainer
  • hippynn.experiment.metric_tracker
  • hippynn.experiment.routines
  • hippynn.experiment.serialization
  • diff --git a/_sources/api_documentation/hippynn.experiment.lightning_trainer.rst.txt b/_sources/api_documentation/hippynn.experiment.lightning_trainer.rst.txt new file mode 100644 index 00000000..0ec1714c --- /dev/null +++ b/_sources/api_documentation/hippynn.experiment.lightning_trainer.rst.txt @@ -0,0 +1,7 @@ +hippynn.experiment.lightning\_trainer module +============================================ + +.. automodule:: hippynn.experiment.lightning_trainer + :members: + :undoc-members: + :show-inheritance: diff --git a/_sources/api_documentation/hippynn.experiment.rst.txt b/_sources/api_documentation/hippynn.experiment.rst.txt index 3758f813..73ed4b81 100644 --- a/_sources/api_documentation/hippynn.experiment.rst.txt +++ b/_sources/api_documentation/hippynn.experiment.rst.txt @@ -16,6 +16,7 @@ Submodules hippynn.experiment.controllers hippynn.experiment.device hippynn.experiment.evaluator + hippynn.experiment.lightning_trainer hippynn.experiment.metric_tracker hippynn.experiment.routines hippynn.experiment.serialization diff --git a/_sources/installation.rst.txt b/_sources/installation.rst.txt index 54384e44..4064fea9 100644 --- a/_sources/installation.rst.txt +++ b/_sources/installation.rst.txt @@ -10,16 +10,18 @@ Requirements: * Python_ >= 3.9 * pytorch_ >= 1.9 * numpy_ + Optional Dependencies: * triton_ (recommended, for improved GPU performance) * numba_ (recommended for improved CPU performance) - * cupy_ (Alternative for accelerating GPU performance) - * ASE_ (for usage with ase) + * cupy_ (alternative for accelerating GPU performance) + * ASE_ (for usage with ase and other misc. features) * matplotlib_ (for plotting) * tqdm_ (for progress bars) - * graphviz_ (for viewing model graphs as figures) + * graphviz_ (for visualizing model graphs) * h5py_ (for loading ani-h5 datasets) * pyanitools_ (for loading ani-h5 datasets) + * pytorch-lightning_ (for distributed training) Interfacing codes: * ASE_ @@ -40,7 +42,7 @@ Interfacing codes: .. _ASE: https://wiki.fysik.dtu.dk/ase/ .. _LAMMPS: https://www.lammps.org/ .. _PYSEQM: https://github.com/lanl/PYSEQM - +.. _pytorch-lightning: https://github.com/Lightning-AI/pytorch-lightning Installation Instructions ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -67,9 +69,6 @@ Clone the hippynn_ repository and navigate into it, e.g.:: .. _hippynn: https://github.com/lanl/hippynn/ -.. note:: - If you wish to do a cpu-only install, you may need to comment - out ``cupy`` from the conda_requirements.txt file. Dependencies using conda ........................ @@ -78,6 +77,10 @@ Install dependencies from conda using recommended channels:: $ conda install -c pytorch -c conda-forge --file conda_requirements.txt +.. note:: + If you wish to do a cpu-only install, you may need to comment + out ``cupy`` from the conda_requirements.txt file. + Dependencies using pip ....................... diff --git a/_sources/user_guide/settings.rst.txt b/_sources/user_guide/settings.rst.txt index d8657de4..c6764206 100644 --- a/_sources/user_guide/settings.rst.txt +++ b/_sources/user_guide/settings.rst.txt @@ -31,7 +31,7 @@ The following settings are available: - Dynamic * - PROGRESS - Progress bars function during training, evaluation, and prediction - - tqdm, none + - tqdm, none, or floating point string specifying default update rate in seconds (default 1). - tqdm - Yes, but assign this to a generator-wrapper such as ``tqdm.tqdm``, or with a python ``None`` to disable. The wrapper must accept ``tqdm`` arguments, although it technically doesn't have to do anything with them. 
* - DEFAULT_PLOT_FILETYPE diff --git a/api_documentation/hippynn.custom_kernels.env_triton.html b/api_documentation/hippynn.custom_kernels.env_triton.html index 44efd06c..1e11e09f 100644 --- a/api_documentation/hippynn.custom_kernels.env_triton.html +++ b/api_documentation/hippynn.custom_kernels.env_triton.html @@ -59,7 +59,16 @@
  • hippynn.custom_kernels.env_cupy module
  • hippynn.custom_kernels.env_numba module
  • hippynn.custom_kernels.env_pytorch module
  • -
  • hippynn.custom_kernels.env_triton module
  • +
  • hippynn.custom_kernels.env_triton module +
  • hippynn.custom_kernels.fast_convert module
  • hippynn.custom_kernels.tensor_wrapper module
  • hippynn.custom_kernels.test_env_cupy module
  • @@ -116,8 +125,47 @@
    -
    -

    hippynn.custom_kernels.env_triton module

    +
    +

    hippynn.custom_kernels.env_triton module

    +
    +
    +config_pruner(configs, nargs, **kwargs)[source]
    +

    Trims the unnecessary config options based on the sens. and feat. sizes

    +
    + +
    +
    +envsum(sense, features, pfirst, psecond)[source]
    +
    + +
    +
    +envsum_triton(sensitivities, features, pair_first, pair_second, atom_ids, atom_starts, out_env=None)[source]
    +
    + +
    +
    +featsum(env, sense, pfirst, psecond)[source]
    +
    + +
    +
    +featsum_triton(env, sense, pair_first, pair_second, atom2_ids, atom2_starts, out_feat=None)[source]
    +
    + +
    +
    +get_autotune_config()[source]
    +

    Create a list of config options for the kernels +TODO: Need to spend time actually figuring out more reasonable options +targeted for modern GPUs

    +
    + +
    +
    +sensesum(env, features, pair_first, pair_second, out_sense=None)[source]
    +
    +
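For orientation, a minimal sketch of how this module's pieces fit together is given below. It assumes the usual hippynn pair-indexed tensor layout (per-pair sensitivities, per-atom features, and two integer index tensors of pair partners); the sizes, the random data, and the expected output shapes noted in the comments are illustrative assumptions rather than part of the documented API, and the autotune wiring at the end is only the typical Triton pattern, not a verbatim copy of the module source.

    import torch
    from hippynn.custom_kernels.env_triton import envsum, sensesum, featsum

    n_atoms, n_pairs, n_sens, n_feat = 10, 40, 20, 16
    device = "cuda"  # the triton kernels target GPU tensors; CPU tensors use the numba/pytorch fallback

    sense = torch.rand(n_pairs, n_sens, device=device)              # per-pair sensitivities
    features = torch.rand(n_atoms, n_feat, device=device)           # per-atom features
    pfirst = torch.randint(0, n_atoms, (n_pairs,), device=device)   # central atom index of each pair
    psecond = torch.randint(0, n_atoms, (n_pairs,), device=device)  # neighbor atom index of each pair

    env = envsum(sense, features, pfirst, psecond)          # expected shape: (n_atoms, n_sens, n_feat)
    new_sense = sensesum(env, features, pfirst, psecond)    # expected shape: (n_pairs, n_sens)
    new_feat = featsum(env, sense, pfirst, psecond)         # expected shape: (n_atoms, n_feat)

    # The autotuned kernels in this module are typically decorated roughly as:
    # @triton.autotune(configs=get_autotune_config(),
    #                  key=["sens_size", "feat_size"],
    #                  prune_configs_by={"early_config_prune": config_pruner})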
    diff --git a/api_documentation/hippynn.custom_kernels.html b/api_documentation/hippynn.custom_kernels.html index 9f8c3800..221d25a4 100644 --- a/api_documentation/hippynn.custom_kernels.html +++ b/api_documentation/hippynn.custom_kernels.html @@ -186,7 +186,16 @@

    Submodulessensesum() -
  • hippynn.custom_kernels.env_triton module
  • +
  • hippynn.custom_kernels.env_triton module +
  • hippynn.custom_kernels.fast_convert module diff --git a/api_documentation/hippynn.custom_kernels.test_env_triton.html b/api_documentation/hippynn.custom_kernels.test_env_triton.html index a4076a9c..33c34c9b 100644 --- a/api_documentation/hippynn.custom_kernels.test_env_triton.html +++ b/api_documentation/hippynn.custom_kernels.test_env_triton.html @@ -116,8 +116,8 @@
    -
    -

    hippynn.custom_kernels.test_env_triton module

    +
    +

    hippynn.custom_kernels.test_env_triton module

    diff --git a/api_documentation/hippynn.databases.SNAPJson.html b/api_documentation/hippynn.databases.SNAPJson.html index 6b2bb4c2..025c5d4f 100644 --- a/api_documentation/hippynn.databases.SNAPJson.html +++ b/api_documentation/hippynn.databases.SNAPJson.html @@ -149,6 +149,9 @@
  • allow_unfound – If true, skip checking if the needed inputs and targets are found. This allows setting inputs=None and/or targets=None.

  • auto_split – If true, look for keys like “split_*” to make initial splits from. See write_npz() method.

  • +
  • device – if set, move the dataset to this device after splitting.

  • +
  • dataloader_kwargs – dictionary, passed to pytorch dataloaders in addition to num_workers, pin_memory. +Refer to pytorch documentation for details.

  • quiet – If True, print little or nothing while loading.

  • diff --git a/api_documentation/hippynn.databases.database.html b/api_documentation/hippynn.databases.database.html index 532916f4..627ab977 100644 --- a/api_documentation/hippynn.databases.database.html +++ b/api_documentation/hippynn.databases.database.html @@ -147,12 +147,12 @@

    Base database functionality from a dictionary of numpy arrays

    -class Database(arr_dict, inputs, targets, seed, test_size=None, valid_size=None, num_workers=0, pin_memory=True, allow_unfound=False, auto_split=False, device=None, quiet=False)[source]
    +class Database(arr_dict: dict[str, ~torch.Tensor], inputs: list[str], targets: list[str], seed: [<class 'int'>, <class 'numpy.random.mtrand.RandomState'>, <class 'tuple'>], test_size: float | int | None = None, valid_size: float | int | None = None, num_workers: int = 0, pin_memory: bool = True, allow_unfound: bool = False, auto_split: bool = False, device: ~torch.device | None = None, dataloader_kwargs: dict[str, object] | None = None, quiet=False)[source]

    Bases: object

    Class for holding a pytorch dataset, splitting it, generating dataloaders, etc.

    -__init__(arr_dict, inputs, targets, seed, test_size=None, valid_size=None, num_workers=0, pin_memory=True, allow_unfound=False, auto_split=False, device=None, quiet=False)[source]
    +__init__(arr_dict: dict[str, ~torch.Tensor], inputs: list[str], targets: list[str], seed: [<class 'int'>, <class 'numpy.random.mtrand.RandomState'>, <class 'tuple'>], test_size: float | int | None = None, valid_size: float | int | None = None, num_workers: int = 0, pin_memory: bool = True, allow_unfound: bool = False, auto_split: bool = False, device: ~torch.device | None = None, dataloader_kwargs: dict[str, object] | None = None, quiet=False)[source]
    Parameters:
      @@ -169,6 +169,9 @@
    • allow_unfound – If true, skip checking if the needed inputs and targets are found. This allows setting inputs=None and/or targets=None.

    • auto_split – If true, look for keys like “split_*” to make initial splits from. See write_npz() method.

    • +
    • device – if set, move the dataset to this device after splitting.

    • +
    • dataloader_kwargs – dictionary, passed to pytorch dataloaders in addition to num_workers, pin_memory. +Refer to pytorch documentation for details.

    • quiet – If True, print little or nothing while loading.

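To make the new device and dataloader_kwargs arguments concrete (and the compressed flag on write_npz documented just below), here is a hedged sketch of building a Database from in-memory arrays. The array names ("Z", "R", "T"), their sizes, and the extra DataLoader keyword are illustrative assumptions, not values prescribed by hippynn.

    import numpy as np
    from hippynn.databases import Database

    arr_dict = {
        "Z": np.random.randint(1, 9, size=(100, 12)),         # species, one row per system
        "R": np.random.rand(100, 12, 3).astype(np.float32),   # positions
        "T": np.random.rand(100, 1).astype(np.float32),       # target energies
    }

    db = Database(
        arr_dict,
        inputs=["Z", "R"],
        targets=["T"],
        seed=0,
        test_size=0.1,
        valid_size=0.1,
        # device=torch.device("cuda"),             # optionally move the split dataset to a device
        dataloader_kwargs={"drop_last": True},     # extra kwargs forwarded to the torch DataLoader
    )

    # write_npz now also accepts `compressed`, which presumably selects numpy's
    # compressed .npz format over the uncompressed one.
    db.write_npz("my_dataset", record_split_masks=True, compressed=True, overwrite=True)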
    @@ -379,7 +382,7 @@
    -write_npz(file: str, record_split_masks: bool = True, overwrite: bool = False, split_prefix=None, return_only=False)[source]
    +write_npz(file: str, record_split_masks: bool = True, compressed: bool = True, overwrite: bool = False, split_prefix=None, return_only=False)[source]
    Parameters:
      diff --git a/api_documentation/hippynn.databases.html b/api_documentation/hippynn.databases.html index ad2946bd..2fd00500 100644 --- a/api_documentation/hippynn.databases.html +++ b/api_documentation/hippynn.databases.html @@ -195,6 +195,9 @@
    • allow_unfound – If true, skip checking if the needed inputs and targets are found. This allows setting inputs=None and/or targets=None.

    • auto_split – If true, look for keys like “split_*” to make initial splits from. See write_npz() method.

    • +
    • device – if set, move the dataset to this device after splitting.

    • +
    • dataloader_kwargs – dictionary, passed to pytorch dataloaders in addition to num_workers, pin_memory. +Refer to pytorch documentation for details.

    • quiet – If True, print little or nothing while loading.

    @@ -226,12 +229,12 @@
    -class Database(arr_dict, inputs, targets, seed, test_size=None, valid_size=None, num_workers=0, pin_memory=True, allow_unfound=False, auto_split=False, device=None, quiet=False)[source]
    +class Database(arr_dict: dict[str, ~torch.Tensor], inputs: list[str], targets: list[str], seed: [<class 'int'>, <class 'numpy.random.mtrand.RandomState'>, <class 'tuple'>], test_size: float | int | None = None, valid_size: float | int | None = None, num_workers: int = 0, pin_memory: bool = True, allow_unfound: bool = False, auto_split: bool = False, device: ~torch.device | None = None, dataloader_kwargs: dict[str, object] | None = None, quiet=False)[source]

    Bases: object

    Class for holding a pytorch dataset, splitting it, generating dataloaders, etc.

    -__init__(arr_dict, inputs, targets, seed, test_size=None, valid_size=None, num_workers=0, pin_memory=True, allow_unfound=False, auto_split=False, device=None, quiet=False)[source]
    +__init__(arr_dict: dict[str, ~torch.Tensor], inputs: list[str], targets: list[str], seed: [<class 'int'>, <class 'numpy.random.mtrand.RandomState'>, <class 'tuple'>], test_size: float | int | None = None, valid_size: float | int | None = None, num_workers: int = 0, pin_memory: bool = True, allow_unfound: bool = False, auto_split: bool = False, device: ~torch.device | None = None, dataloader_kwargs: dict[str, object] | None = None, quiet=False)[source]
    Parameters:
      @@ -248,6 +251,9 @@
    • allow_unfound – If true, skip checking if the needed inputs and targets are found. This allows setting inputs=None and/or targets=None.

    • auto_split – If true, look for keys like “split_*” to make initial splits from. See write_npz() method.

    • +
    • device – if set, move the dataset to this device after splitting.

    • +
    • dataloader_kwargs – dictionary, passed to pytorch dataloaders in addition to num_workers, pin_memory. +Refer to pytorch documentation for details.

    • quiet – If True, print little or nothing while loading.

    @@ -458,7 +464,7 @@
    -write_npz(file: str, record_split_masks: bool = True, overwrite: bool = False, split_prefix=None, return_only=False)[source]
    +write_npz(file: str, record_split_masks: bool = True, compressed: bool = True, overwrite: bool = False, split_prefix=None, return_only=False)[source]
    Parameters:
      @@ -516,6 +522,9 @@
    • allow_unfound – If true, skip checking if the needed inputs and targets are found. This allows setting inputs=None and/or targets=None.

    • auto_split – If true, look for keys like “split_*” to make initial splits from. See write_npz() method.

    • +
    • device – if set, move the dataset to this device after splitting.

    • +
    • dataloader_kwargs – dictionary, passed to pytorch dataloaders in addition to num_workers, pin_memory. +Refer to pytorch documentation for details.

    • quiet – If True, print little or nothing while loading.

    @@ -557,6 +566,9 @@
  • allow_unfound – If true, skip checking if the needed inputs and targets are found. This allows setting inputs=None and/or targets=None.

  • auto_split – If true, look for keys like “split_*” to make initial splits from. See write_npz() method.

  • +
  • device – if set, move the dataset to this device after splitting.

  • +
  • dataloader_kwargs – dictionary, passed to pytorch dataloaders in addition to num_workers, pin_memory. +Refer to pytorch documentation for details.

  • quiet – If True, print little or nothing while loading.

  • diff --git a/api_documentation/hippynn.databases.ondisk.html b/api_documentation/hippynn.databases.ondisk.html index f822a197..3e09d4cf 100644 --- a/api_documentation/hippynn.databases.ondisk.html +++ b/api_documentation/hippynn.databases.ondisk.html @@ -169,6 +169,9 @@
  • allow_unfound – If true, skip checking if the needed inputs and targets are found. This allows setting inputs=None and/or targets=None.

  • auto_split – If true, look for keys like “split_*” to make initial splits from. See write_npz() method.

  • +
  • device – if set, move the dataset to this device after splitting.

  • +
  • dataloader_kwargs – dictionary, passed to pytorch dataloaders in addition to num_workers, pin_memory. +Refer to pytorch documentation for details.

  • quiet – If True, print little or nothing while loading.

  • @@ -210,6 +213,9 @@
  • allow_unfound – If true, skip checking if the needed inputs and targets are found. This allows setting inputs=None and/or targets=None.

  • auto_split – If true, look for keys like “split_*” to make initial splits from. See write_npz() method.

  • +
  • device – if set, move the dataset to this device after splitting.

  • +
  • dataloader_kwargs – dictionary, passed to pytorch dataloaders in addition to num_workers, pin_memory. +Refer to pytorch documentation for details.

  • quiet – If True, print little or nothing while loading.

  • diff --git a/api_documentation/hippynn.experiment.assembly.html b/api_documentation/hippynn.experiment.assembly.html index 1930f8e0..a6f72f8a 100644 --- a/api_documentation/hippynn.experiment.assembly.html +++ b/api_documentation/hippynn.experiment.assembly.html @@ -79,6 +79,7 @@
  • hippynn.experiment.controllers module
  • hippynn.experiment.device module
  • hippynn.experiment.evaluator module
  • +
  • hippynn.experiment.lightning_trainer module
  • hippynn.experiment.metric_tracker module
  • hippynn.experiment.routines module
  • hippynn.experiment.serialization module
  • diff --git a/api_documentation/hippynn.experiment.controllers.html b/api_documentation/hippynn.experiment.controllers.html index a4eecb15..f276e9b6 100644 --- a/api_documentation/hippynn.experiment.controllers.html +++ b/api_documentation/hippynn.experiment.controllers.html @@ -81,6 +81,7 @@
  • RaiseBatchSizeOnPlateau
  • -class RaiseBatchSizeOnPlateau(optimizer, max_batch_size, factor=0.5, patience=10, threshold=0.0001, threshold_mode='rel', verbose=True, controller=None)[source]
    -

    Bases: object

    +class RaiseBatchSizeOnPlateau(optimizer, max_batch_size, factor=0.5, patience=10, threshold=0.0001, threshold_mode='rel', verbose=None, controller=None)[source] +

    Bases: ReduceLROnPlateau

    Learning rate scheduler compatible with pytorch schedulers.

    +

    Note: the “verbose” parameter has been deprecated and no longer does anything.

    This roughly implements the scheme outlined in the following paper:

    "Don't Decay the Learning Rate, Increase the Batch Size"
     Samuel L. Smith et al., 2018.
    @@ -253,12 +256,39 @@
     
    -__init__(optimizer, max_batch_size, factor=0.5, patience=10, threshold=0.0001, threshold_mode='rel', verbose=True, controller=None)[source]
    -
    +__init__(optimizer, max_batch_size, factor=0.5, patience=10, threshold=0.0001, threshold_mode='rel', verbose=None, controller=None)[source] +
    +
    Parameters:
    +
      +
    • optimizer

    • +
    • max_batch_size

    • +
    • factor

    • +
    • patience

    • +
    • threshold

    • +
    • threshold_mode

    • +
    • verbose

    • +
    • controller

    • +
    +
    +
    +
    load_state_dict(state_dict)[source]
    +

    Loads the scheduler's state.

    +
    +
    Args:
    +
    state_dict (dict): scheduler state. Should be an object returned

    from a call to state_dict().

    +
    +
    +
    +
    +
    + +
    +
    +property optimizer
    @@ -269,7 +299,10 @@
    state_dict()[source]
    -
    +

    Returns the state of the scheduler as a dict.

    +

    It contains an entry for every variable in self.__dict__ which +is not the optimizer.

    +
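The state_dict/load_state_dict pair above follows the standard pytorch scheduler checkpointing pattern; a minimal round-trip sketch, assuming the scheduler object from the earlier example:

    # Save everything except the optimizer (excluded from state_dict(), as noted above);
    # the optimizer's own state_dict() must be checkpointed separately.
    state = scheduler.state_dict()
    torch.save(state, "scheduler_state.pt")

    scheduler.load_state_dict(torch.load("scheduler_state.pt"))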
    diff --git a/api_documentation/hippynn.experiment.device.html b/api_documentation/hippynn.experiment.device.html index a8c128ce..f2d90aec 100644 --- a/api_documentation/hippynn.experiment.device.html +++ b/api_documentation/hippynn.experiment.device.html @@ -69,6 +69,7 @@
  • hippynn.experiment.evaluator module
  • +
  • hippynn.experiment.lightning_trainer module
  • hippynn.experiment.metric_tracker module
  • hippynn.experiment.routines module
  • hippynn.experiment.serialization module
  • diff --git a/api_documentation/hippynn.experiment.evaluator.html b/api_documentation/hippynn.experiment.evaluator.html index 20deabe0..5bb35971 100644 --- a/api_documentation/hippynn.experiment.evaluator.html +++ b/api_documentation/hippynn.experiment.evaluator.html @@ -22,7 +22,7 @@ - + @@ -74,6 +74,7 @@ +
  • hippynn.experiment.lightning_trainer module
  • hippynn.experiment.metric_tracker module
  • hippynn.experiment.routines module
  • hippynn.experiment.serialization module
  • @@ -119,7 +120,7 @@

    @@ -178,7 +179,7 @@