diff --git a/JGNN/neural_graph.jggn b/JGNN/neural_graph.jggn new file mode 100644 index 00000000..540bc2c1 --- /dev/null +++ b/JGNN/neural_graph.jggn @@ -0,0 +1,49 @@ +mklab.JGNN.adhoc.parsers.LayeredBuilder +features = config: 2.0 +hidden = config: 8.0 +reg = config: 0.005 +classes = config: 2.0 +reduced = config: 6.0 +2hidden = config: 16.0 +hiddenReduced = config: 48.0 +A = var: null +h0 = var: null +_tmp11 = param DenseMatrix (2hidden 16, hidden 8): [-0.19328476877863923,-0.6816842687001173,0.33084330660685063,0.29012383747022974,-0.008856823759693731,0.3243209298674903,0.03630835867502023,-0.36893022633353045,0.3852823246760141,0.1675140168304929,0.34749646804722034,-0.043510359020801116,-0.5480261005662305,-0.36506518491627604,0.4618722893431745,-0.31670662910836095,-0.06766959139063297,3.964965286227935E-4,0.42952042800161827,-0.6955929308960008,-0.6286510760021554,-0.0965623432938279,-0.18299829835981102,-0.05956609007597157,-0.2920302998132937,-0.2547563636041248,0.15096554826682856,0.3775004405962659,-0.8995334446683099,0.06333771747725044,0.65460552575879,-0.3186670866531967,-0.5309119132179063,0.394784224886659,0.29509546796445657,-0.4690491776981738,-0.10572405664194967,0.27624178602476357,0.019287597474623407,-0.08221736520461895,0.3752314288377563,-0.14388855480059862,0.5031819562803397,-0.39231949747503947,0.030226728241452908,0.539230918192968,-0.1866028170582539,-0.054814542702866335,-0.037001880635509805,-0.05866626266145604,0.6570208369054049,0.4037472351461765,-0.7434327464304322,-0.1800452245270526,0.40850049372980474,0.22084528373229415,0.01669375687488634,-0.01769463510961822,0.48369486751642693,-0.4513516083571352,-0.18643307652802016,0.47531322785792024,-0.15418322605338194,0.08869303689826054,0.6496841665481397,0.12057361296404132,-0.10087202695777087,0.20077395712411783,0.12298342734788265,0.4463359109741554,0.595268214487788,0.3198579540060096,0.15003772422143247,0.3036177944924227,0.39176025176862367,0.4385176746516399,0.29267786236453575,-0.538075342791701,0.18692588008664582,-0.22874264781324213,-0.2956689694666771,0.08507352063026218,0.16405172192265624,-0.16627517419517734,0.7367038620405035,0.23692065568506798,-0.39340620243769026,-0.1218118529198394,0.34270271396790364,0.10611754538241865,0.23511719104140957,0.06257205623691434,-0.026719449424170108,0.13262598343728552,0.3321072492887023,-0.322213838336707,-0.09058229994288335,-0.03864904977375033,0.18830357531855324,0.035130801296273106,-0.40433428753295697,0.34087491252629254,0.8426791513172762,0.07823534114592788,0.017292176905372218,-0.07825012073294113,-0.551117338968398,0.5990794942026233,-0.3498034727233782,-0.15620124985299463,-0.1313686261579107,-0.17445929115062758,0.05423176395512194,0.3305972373060727,-0.07433345988400326,0.02211377023555965,0.5169316243904939,-0.35626175262079174,0.2605377618616539,0.040877884464697034,0.6011922447538958,0.10590649114094869,0.04866151867585637,1.0498653811003313,-0.4831934782954236,0.2155992503280004,-0.4449982638493458,-0.15205818353654899] +_tmp12 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] +_tmp17 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] +_tmp16 = param DenseMatrix (2hidden 16, hidden 8): 
[0.5021034187023461,0.13868567394933753,-0.4319463012253357,-0.01957776716822693,-0.6094145749292609,0.1439419292273619,-0.10279569620728086,-0.24391336506490913,-0.07060120291020784,0.365303228749883,0.46701090217314395,-0.07945034851001198,0.18256005736919872,0.523490149473091,0.3783930271864008,-0.5348058052156716,0.248710923254586,-0.18060630411319933,-0.9064084134026116,-0.5047861771258582,-0.21041070643734053,-0.2508198800475379,-0.19029786209348837,-0.34694159663184004,0.02890813945657403,0.11926584949928104,0.008029066336753142,-0.30160347629852446,0.4142657749089172,-0.10276450006376554,0.21668328634092998,0.7794994419864524,0.31912230822442206,-0.7896441157875099,0.0669323910449409,0.005079087575026865,0.3871876711708779,-0.027795613673683106,0.5195156321578849,0.1451244649129374,-0.1234044285044108,-0.3910978872449653,0.19433235371428767,-0.484786863765534,-0.06500045233351805,-0.13535973107007787,-0.001400944670307145,0.14315247913721696,-0.44489526404547725,-0.7113623734549364,-0.3537344028937658,-0.2826613161956493,0.38693511145728415,-0.5540629852227372,-0.26564026626370785,-0.7528758353657103,-0.5030481946916305,0.7361802029086209,0.048380119415574255,0.19193865286784526,-0.18572427292244453,-0.1070739962101043,-0.4667238425266632,0.1724481017787257,0.07907853007683265,0.07794419563775512,-0.2258137337359118,0.7440959006326245,-0.684988546540687,0.17920646820148134,0.46240699117307765,0.4320721338608193,-0.37346377979218376,-0.4607102416897734,-0.5562339107003924,-1.1335825147833678,-0.6780782600923088,0.3572746140653733,-0.41292103314568673,-0.31905109650720387,-0.25370387856853094,-0.480925487249198,-0.3151921264430039,0.07776796917965208,-0.1788986631998994,0.08475066278957165,-0.5237351973530295,0.4193834593785406,0.03681254346412109,-0.584427852645633,-0.02896556963429372,-0.1299053244734533,-0.2218704270185149,0.24571193054996324,0.42048425749103463,-0.8880484715017992,0.31700958719137495,-0.4581902441160126,-0.30034222408803624,-0.7713140669541653,0.1906079322091254,0.5614853908408053,-0.12311847026365601,0.094618305065478,0.16756886278939304,0.5190453004025772,0.1461756848410729,0.10340340085403356,0.009543912257454283,-0.4945736675559478,0.18117986153395202,0.4416916735062306,0.2303877229303233,0.22360404940516052,-0.6112281541522248,-0.3049268148884489,-0.21320808341784314,0.23185142098280917,0.4502900379269799,-0.165791378027787,0.7864019252894919,0.2813364624331278,0.9474109407516412,0.12430889772546948,-0.2172150595120388,0.4876508580337231,0.344365758524931,0.36338942310374867] +_tmp20 = param DenseMatrix (hidden 8, hidden 8): 
[0.6814404680132916,0.25361885765312314,0.10256163018756725,0.1922851650348111,-0.05410077006150943,-0.38549110279625004,0.5004202987313823,-0.2324453468875155,-0.22099505283361742,0.06551162436227614,0.5271888950214714,0.22805732984591637,0.42595558449507714,-0.2595469927668719,0.41457441383114,-0.40693167679055886,0.39803040041408844,0.3440959433142813,-0.5753750505438064,0.06237296287559791,-0.4211115226435094,-0.1763906243200833,1.1985422471982727,-0.07676527261740763,0.2326178151531441,0.35704961284054426,-0.26729485975443884,0.4095283664505398,0.18168483029293744,-0.2932381742502261,0.8414330706272171,0.37977048763315235,0.5062335030466366,0.9869523173389951,1.6838631648078104,0.1263031934267131,-0.5478652955503843,-3.035840047741802E-4,-0.3174070977338061,0.7191699495163361,-0.07943020358021734,0.21984277828430565,0.8033588782047216,-0.1861673333583821,0.5596593401061914,-0.35031592936519707,-0.37717583187856096,0.23602515813879557,0.0449456664499197,-0.47412881293932263,0.21589515054272587,-0.20568191860829724,0.2901169807397064,0.4150953875668683,0.03210280126721354,-0.7381724464055942,0.08744218760930367,1.657249410635567,0.38200285347323787,0.41650869362747445,0.12939056231065268,0.40376690602764725,-0.047337732760709926,0.6966554415998744] +_tmp21 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] +_tmp3 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] +_tmp2 = param DenseMatrix (features 2, hidden 8): [0.0962596204088395,-0.028884691451251224,-0.028116488038815862,0.3577563699462411,-0.1915495424716903,-0.5879573411284648,0.16687000433532928,0.18406857056345563,-0.5892489716766497,0.9247104404685634,-0.06459238390322171,-1.1421568977936178,-0.722566667961998,-1.3796258912305088,-0.06280695024528939,1.7517702074586412] +_tmp23 = param DenseMatrix (hiddenReduced 48, classes 2): 
[0.40938433617236525,-0.4987199441662107,0.26203992378101376,-0.43133379215734263,0.6709105283816391,-0.9814661346740309,0.539095205750881,-0.4207697065158933,0.0240229334618446,-0.14223548752189802,0.4438842753898452,-0.49120390072985726,0.17395173314865078,0.06471401007521176,0.4478543276829584,0.43432759563978957,0.3447420155314258,-1.1253578954761723,0.4993227096710841,0.30042014773726416,0.6795526379343243,-0.5195903531075249,0.4621000372201286,-0.7689859292393859,0.9106708726562932,-0.6086682072036363,-0.49443734786718163,-2.0015702142041314,0.2759331578478036,-0.9644772192516243,-0.14250917516719866,-5.04944040825525,1.8578946969831234,-2.689873422142453,17.819588902507228,2.3524952524944034,0.607415470001217,-0.005342721841655263,0.2345666718277672,-0.6435856594935998,0.09465516271962526,-1.028924483552736,0.5494581031684685,0.004150240845003633,0.09245869079258418,-1.159928537758163,0.5395232430817755,-0.8876905496583974,-0.43672870130761726,-0.6001793288795612,-0.2209758164529676,0.3771985686838059,-0.5264231473231888,1.2086668405130667,-0.3677004028561642,-0.23961520350564466,-0.1312633323701847,0.20028985803474514,-0.5997756663497473,1.1861726582546483,-0.5171240051903958,-0.36876941825191734,-0.08069918314199832,-0.6008323309697636,0.2923273123036548,1.4251798900771944,-0.3047689840101401,-0.11336172298411515,-0.5004652657880019,-0.25323032177253096,-0.04182412745906058,0.7914127746669496,-0.7185649111348198,0.4124581236471469,0.8006830216814292,1.4727565356504557,-0.98152228243357,0.6496030752864841,0.2086967713555983,5.221403415918512,-2.005888671728391,2.2603888220693245,-17.575851034929514,-2.500026725809498,-0.24434030852837782,-0.33232526683552893,0.3851370120960596,1.4858460025078009,-0.421678210018014,0.7016615436437713,-0.578971814276199,0.029161238557184666,-0.06956312718637947,1.0065065640325745,-0.5451395936226792,1.0762157054167099] +_tmp5 = param DenseMatrix (hidden 8, hidden 8): [-0.30954935594367816,0.12116378595232227,-0.306386786291067,-0.5220864721581343,-0.003000295537359033,-0.37628978694357235,-0.2695812643245417,0.6700734618145359,0.20645751601752704,-0.05260439493400403,0.17754381848591488,-0.17735111718384539,0.167779299276608,0.27389750680988467,0.23764308545771937,1.6708105311783883,0.06728854514002967,-0.44839216787805547,0.13632607512869982,0.37582277585252005,-0.28216295775529515,-0.5678759903551267,-0.35640357469197825,0.3560222333683681,-0.06393612839774328,-0.43344284383245607,-0.3986854818366697,0.26732202379329006,0.34858445673548166,0.01968688359473008,0.7123181505558769,-0.6411211857602117,-0.271134106235963,0.2491391197457063,0.0939090724621519,-0.7099186482668389,0.4388952099959059,0.8554823355249708,-1.166165039369413,1.1988479095710083,-0.2320232541561487,-0.08166507448918221,-0.41733625408668373,-0.6475038987362935,-0.4785141909120537,-0.6513421097528718,0.7196243595671512,0.0627929089576402,0.08109030285230427,-2.3186007809859356E-4,0.06453881213468693,0.7580057257515331,-0.12080130803301743,-0.17046938190024785,0.28216868068309153,0.005635712488808225,-0.45095495159758453,0.13584676111699664,-0.73411011182003,-0.22715114851512125,-0.5224650791751124,0.8293987840564809,0.021505674769915967,-0.2903500417602574] +_tmp6 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] +edgeSrc = from A +edgeDst = to A +_tmp1 = h0 @ _tmp2 +_tmp0 = _tmp1 + _tmp3 +h1 = relu _tmp0 +_tmp4 = h1 @ _tmp5 +h2 = _tmp4 + _tmp6 +_tmp7 = h2 [ edgeSrc ] +_tmp8 = h2 [ edgeDst ] +message2 = _tmp7 | _tmp8 +_tmp10 = message2 @ _tmp11 +_tmp9 = _tmp10 + _tmp12 
+transformed2 = relu _tmp9
+received2 = reduce ( transformed2 , A )
+_tmp14 = received2 | h2
+_tmp15 = _tmp14 @ _tmp16
+_tmp13 = _tmp15 + _tmp17
+i2 = relu _tmp13
+_tmp19 = i2 @ _tmp20
+_tmp18 = _tmp19 + _tmp21
+h3 = relu _tmp18
+z3 = sort ( h3 , reduced )
+_tmp22 = h3 [ z3 ]
+h4 = reshape ( _tmp22 , 1 , hiddenReduced )
+h5 = h4 @ _tmp23
+h6 = softmax ( h5 , row )
+
+return h6
diff --git a/JGNN/src/examples/graphClassification/MessageSortPooling.java b/JGNN/src/examples/graphClassification/MessageSortPooling.java
new file mode 100644
index 00000000..eadd3af2
--- /dev/null
+++ b/JGNN/src/examples/graphClassification/MessageSortPooling.java
@@ -0,0 +1,99 @@
+package graphClassification;
+
+import java.util.Arrays;
+
+import mklab.JGNN.adhoc.ModelBuilder;
+import mklab.JGNN.adhoc.parsers.LayeredBuilder;
+import mklab.JGNN.core.Matrix;
+import mklab.JGNN.core.Tensor;
+import mklab.JGNN.core.ThreadPool;
+import mklab.JGNN.nn.Loss;
+import mklab.JGNN.nn.Model;
+import mklab.JGNN.nn.initializers.XavierNormal;
+import mklab.JGNN.nn.loss.CategoricalCrossEntropy;
+import mklab.JGNN.nn.optimizers.Adam;
+import mklab.JGNN.nn.optimizers.BatchOptimizer;
+
+/**
+ *
+ * @author github.com/gavalian
+ * @author Emmanouil Krasanakis
+ */
+public class MessageSortPooling {
+
+	public static void main(String[] args){
+		long reduced = 5; // input graphs need to have at least that many nodes; lower values decrease accuracy
+		long hidden = 8; // since this library does not use GPU parallelization, many latent dims reduce speed
+
+		ModelBuilder builder = new LayeredBuilder()
+				.var("A")
+				.config("features", 1)
+				.config("classes", 2)
+				//.config("reduced", reduced)
+				.config("hidden", hidden)
+				.config("2hidden", 2*hidden)
+				.config("reg", 0.005)
+				.operation("edgeSrc = from(A)")
+				.operation("edgeDst = to(A)")
+				.layer("h{l+1}=relu(h{l}@matrix(features, hidden, reg)+vector(hidden))")
+				.layer("h{l+1}=h{l}@matrix(hidden, hidden, reg)+vector(hidden)")
+
+				// message passing layer (make it as complex as needed)
+				.operation("message{l}=h{l}[edgeSrc] | h{l}[edgeDst]")
+				.operation("transformed{l}=relu(message{l}@matrix(2hidden, hidden, reg)+vector(hidden))")
+				.operation("received{l}=reduce(transformed{l}, A)")
+				.operation("i{l}=relu((received{l} | h{l})@matrix(2hidden, hidden, reg)+vector(hidden))")
+				.layer("h{l+1}=relu(i{l}@matrix(hidden, hidden, reg)+vector(hidden))")
+
+				// this would be the sort pooling
+				/*.config("hiddenReduced", hidden*reduced) // reduced * (previous layer's output size)
+				.operation("z{l}=sort(h{l}, reduced)") // currently, the parser fails to understand full expressions within next step's gather, so we need to create this intermediate variable
+				.layer("h{l+1}=reshape(h{l}[z{l}], 1, hiddenReduced)")
+				.layer("h{l+1}=h{l}@matrix(hiddenReduced, classes)")
+				.layer("h{l+1}=softmax(h{l}, row)")*/
+
+				// the following two layers implement the sum pooling
+				.layer("h{l+1}=sum(h{l}@matrix(hidden, classes)+vector(classes), row)")
+				.layer("h{l+1}=softmax(h{l}, row)")
+
+				.out("h{l}");
+
+		TrajectoryData dtrain = new TrajectoryData(800);
+		TrajectoryData dtest = new TrajectoryData(200);
+
+		Model model = builder.getModel().init(new XavierNormal());
+		BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.01));
+		Loss loss = new CategoricalCrossEntropy();
+		for(int epoch=0; epoch<600; epoch++) {
+			// gradient update over all graphs
+			for(int graphId=0; graphId
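The training loop dispatches one `Model.train` call per graph on the shared `ThreadPool` and lets the `BatchOptimizer` accumulate gradients. A minimal sketch of how this typically looks in JGNN graph-classification examples follows; the `graphs`/`features`/`labels` members of `TrajectoryData`, the input ordering, and the exact `Model.train` overload are assumptions here, not part of this patch.

```java
// Hedged sketch of a full training epoch (assumed API, not part of this diff).
for(int epoch=0; epoch<600; epoch++) {
	for(int graphId=0; graphId<dtrain.graphs.size(); graphId++) {
		int id = graphId; // effectively-final copy for the anonymous Runnable
		ThreadPool.getInstance().submit(new Runnable() {
			@Override
			public void run() {
				Tensor adjacency = dtrain.graphs.get(id);      // assumed accessor
				Tensor nodeFeatures = dtrain.features.get(id); // assumed accessor
				Tensor label = dtrain.labels.get(id);          // assumed accessor
				// inputs follow the builder's declaration order: h0, then A
				model.train(loss, optimizer,
						Arrays.asList(nodeFeatures, adjacency),
						Arrays.asList(label));
			}
		});
	}
	ThreadPool.getInstance().waitForConclusion(); // join all graph workers
	optimizer.updateAll(); // apply the accumulated gradients once per epoch
}
```

Because `BatchOptimizer` defers parameter updates until `updateAll()`, every graph in the epoch contributes to a single batched gradient step.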
+	 * @return this Matrix instance.
+	 * @see #symmetricNormalization()
+	 */
+	public Matrix setToASymmetricNormalization() {
+		HashMap<Long, Double> outDegrees = new HashMap<Long, Double>();
+		HashMap<Long, Double> inDegrees = new HashMap<Long, Double>();
+		for(Entry<Long, Long> element : getNonZeroEntries()) {
+			long row = element.getKey();
+			long col = element.getValue();
+			double value = get(row, col);
+			outDegrees.put(row, outDegrees.getOrDefault(row, 0.)+value);
+			inDegrees.put(col, inDegrees.getOrDefault(col, 0.)+value);
+		}
+		for(Entry<Long, Long> element : getNonZeroEntries()) {
+			long row = element.getKey();
+			long col = element.getValue();
+			double div = inDegrees.get(col);
+			if(div!=0)
+				put(row, col, get(row, col)/div);
+		}
+		return this;
+	}
+
 	/**
 	 * Retrieves either the given row or column as a tensor.
 	 * @param index The dimension index to access.
diff --git a/JGNN/src/main/java/mklab/JGNN/core/Memory.java b/JGNN/src/main/java/mklab/JGNN/core/Memory.java
index 6c2c01de..5e97aa17 100644
--- a/JGNN/src/main/java/mklab/JGNN/core/Memory.java
+++ b/JGNN/src/main/java/mklab/JGNN/core/Memory.java
@@ -8,7 +8,7 @@
 import java.util.WeakHashMap;
 
 /**
- * A memory management systems for thread-safe allocation and release of arrays of doubles.
+ * A memory management system for thread-safe allocation and release of arrays of doubles.
  * Soft references to allocated arrays are kept so that released ones can be reused by future
  * allocation calls without explicitly initializing memory.
  *
diff --git a/JGNN/src/main/java/mklab/JGNN/core/Tensor.java b/JGNN/src/main/java/mklab/JGNN/core/Tensor.java
index a1a07810..05448413 100644
--- a/JGNN/src/main/java/mklab/JGNN/core/Tensor.java
+++ b/JGNN/src/main/java/mklab/JGNN/core/Tensor.java
@@ -286,8 +286,9 @@ public final Tensor add(Tensor tensor) {
 		if(density() iterator() {
diff --git a/JGNN/src/main/java/mklab/JGNN/nn/loss/CategoricalCrossEntropy.java b/JGNN/src/main/java/mklab/JGNN/nn/loss/CategoricalCrossEntropy.java
index 37e7b7f3..5c01af3d 100644
--- a/JGNN/src/main/java/mklab/JGNN/nn/loss/CategoricalCrossEntropy.java
+++ b/JGNN/src/main/java/mklab/JGNN/nn/loss/CategoricalCrossEntropy.java
@@ -1,5 +1,6 @@
 package mklab.JGNN.nn.loss;
 
+import mklab.JGNN.core.Matrix;
 import mklab.JGNN.core.Tensor;
 import mklab.JGNN.nn.Loss;
 
@@ -10,6 +11,7 @@
  */
 public class CategoricalCrossEntropy extends Loss {
 	private double epsilon;
+	private boolean meanReduction;
 
 	/**
 	 * Initializes categorical cross entropy with 1.E-12 epsilon value.
@@ -28,13 +30,30 @@ public CategoricalCrossEntropy(double epsilon) {
 		this.epsilon = epsilon;
 	}
 
+	/**
+	 * Sets the reduction mechanism of categorical cross entropy.
+	 * This can be either a sum or a mean across the categorical cross entropy of all data samples.
+	 * @param meanReduction true to perform mean reduction, false (default) for sum reduction.
+	 * @return this CategoricalCrossEntropy object.
+	 */
+	public CategoricalCrossEntropy setMeanReduction(boolean meanReduction) {
+		this.meanReduction = meanReduction;
+		return this;
+	}
+
 	@Override
 	public double evaluate(Tensor output, Tensor desired) {
-		return -output.add(epsilon).selfLog().selfMultiply(desired).sum();// / output.cast(Matrix.class).getRows();
+		double ret = -output.add(epsilon).selfLog().selfMultiply(desired).sum();
+		if(meanReduction)
+			ret /= output.cast(Matrix.class).getRows();
+		return ret;
 	}
 
 	@Override
 	public Tensor derivative(Tensor output, Tensor desired) {
-		return desired.multiply(output.add(epsilon).selfInverse()).negative();//.selfMultiply(-1. / output.cast(Matrix.class).getRows());
+		Tensor ret = desired.multiply(output.add(epsilon).selfInverse()).negative();
+		if(meanReduction)
+			ret.selfMultiply(1. / output.cast(Matrix.class).getRows());
+		return ret;
 	}
 }
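The new reduction flag is opt-in, so existing callers keep the summed behavior. A small usage sketch grounded in the setter above:

```java
// Default: cross entropy summed over all samples (output matrix rows).
Loss summed = new CategoricalCrossEntropy();
// Opt-in: divide the loss and its gradient by the number of samples.
Loss averaged = new CategoricalCrossEntropy().setMeanReduction(true);
```

Mean reduction decouples gradient magnitude from the number of output rows, which keeps learning rates comparable when sample counts vary between batches.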
diff --git a/JGNN/target/classes/mklab/JGNN/core/Matrix.class b/JGNN/target/classes/mklab/JGNN/core/Matrix.class
index 75dcc62a..a8149645 100644
Binary files a/JGNN/target/classes/mklab/JGNN/core/Matrix.class and b/JGNN/target/classes/mklab/JGNN/core/Matrix.class differ
diff --git a/JGNN/target/classes/mklab/JGNN/core/Tensor.class b/JGNN/target/classes/mklab/JGNN/core/Tensor.class
index 0de6028d..bd6caa31 100644
Binary files a/JGNN/target/classes/mklab/JGNN/core/Tensor.class and b/JGNN/target/classes/mklab/JGNN/core/Tensor.class differ
diff --git a/tutorials/Debugging.md b/tutorials/Debugging.md
index c1fe3b1d..9d108807 100644
--- a/tutorials/Debugging.md
+++ b/tutorials/Debugging.md
@@ -36,7 +36,7 @@
 method for model build functional flows that prints all the parsed expressions and
 intermediate expressions in the system console, and b) a `.getExecutionGraphDot()`
 that returns a String holding the execution graph in *.dot* format
 for visualization with external tools, such
-as (GraphViz](https://dreampuf.github.io/GraphvizOnline).
+as [GraphViz](https://dreampuf.github.io/GraphvizOnline).
 
 A second error-checking procedure consists of checking for model operations that do not
diff --git a/tutorials/GNN.md b/tutorials/GNN.md
index 62534054..1e9ee8f1 100644
--- a/tutorials/GNN.md
+++ b/tutorials/GNN.md
@@ -79,7 +79,7 @@
 on the adjacency matrix values on each layer per:
 .layer("h{l+1}=dropout(A,0.5) @ h{l}")
 ```
 
-Recent areas of heterogenous graph research also explicitly use the graph laplacian,
+Recent areas of heterogeneous graph research also explicitly use the graph Laplacian,
 which you can insert into the architecture as a normal constant per
 `.constant("L", adjacency.negative().cast(Matrix.class).setMainDiagonal(1))`.
 Even more complex concepts can be modelled with edge attention that gathers and
 performs the dot product of edge nodes to provide new edge weights, exponentiating
diff --git a/tutorials/Learning.md b/tutorials/Learning.md
index fde5f1a4..7dd6aa19 100644
--- a/tutorials/Learning.md
+++ b/tutorials/Learning.md
@@ -80,7 +80,7 @@
 is presented later in the [debugging](Debugging.md) tutorial.
 
 ## Training
 To train the model, we set up 50-25-25 training-validation-test data slices. These
 basically handle shuffled sample identifiers. You can use integers instead of
-doubles in the `range` method to reference a fixed fixed instead of fractional slice sizes.
+doubles in the `range` method to reference a fixed number of samples instead of fractional slice sizes.
 
 ```java
 Slice samples = dataset.samples().getSlice().shuffle();
 // or samples = new Slice(0, labels.getRows()).shuffle();
 ```
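To make the integer/double distinction of `range` concrete, here is a short sketch; the `Slice` calls mirror the tutorial's own snippet and its 50-25-25 split:

```java
Slice samples = dataset.samples().getSlice().shuffle();
// doubles select fractional sizes: the 50-25-25 split from the text
Slice train = samples.range(0.0, 0.5);
Slice validation = samples.range(0.5, 0.75);
Slice test = samples.range(0.75, 1.0);
// integers select a fixed number of samples instead
Slice firstHundred = samples.range(0, 100);
```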
diff --git a/tutorials/NN.md b/tutorials/NN.md
index d15fd00c..465effac 100644
--- a/tutorials/NN.md
+++ b/tutorials/NN.md
@@ -32,13 +32,15 @@ ModelBuilder modelBuilder = new LayeredBuilder("h0")
 
 ## Deep architectures
 Now that we have explained how simple layers work, let's look
 at two more advanced `LayeredBuilder` methods pivotal to many deep neural networks.
-The first is `.layerRepeat(String, int)`), which just repeats
+The first is `.layerRepeat(String, int)`, which just repeats
 the layer expression a set number of times without breaking the
-functional model definition pipeline. The second is `.concat(int)`. Concatenation
-is also possible with normal parsing with the `|` operation, but this performs it over any
-number of layers.
+functional model definition pipeline. The second is `.concat(int)`,
+which horizontally concatenates a number of top layers. Concatenation
+is also possible in symbolic parsing through the `|` operation,
+but calling the method easily scales it over a large number of layers
+(e.g., across several graph convolutional layers).
 
-We now make a more advanved model:
+We now make a more advanced model using these methods:
 
 ```java
 ModelBuilder modelBuilder = new LayeredBuilder()
@@ -54,7 +56,7 @@
 ```
 
 ## Writing operations
-This is a good point to we present symbols you can use to define operation expressions.
+This is a good point to present symbols you can use within expressions.
 Unless otherwise specified, you can replace x and y with any expression.
 Sometimes, y needs to be a constant defined either by presenting a number,
 calling `ModelBuilder.config(y, double)`, or calling `ModelBuilder.constant(y, double)` to
@@ -75,7 +77,7 @@
 set the numbers as hyperparameters.
 
 | tanh(x) | Function | Apply a tanh activation on each tensor element. |
 | sigmoid(x) | Function | Apply a sigmoid activation on each tensor element. |
 | dropout(x, y) | Function | Apply training dropout on tensor x with constant dropout rate y. |
-| lrely(x, y) | Function | Leaky relu on tensor x with constant negative slope y. |
+| lrelu(x, y) | Function | Leaky relu on tensor x with constant negative slope y. |
 | prelu(x) | Function | Leaky relu on tensor x with learnable negative slope. |
 | softmax(x, y) | Function | Apply y-wide softmax on x, where y is either row or col. |
 | sum(x, y) | Function | Apply y-wide sum reduction on x, where y is either row or col. |
@@ -83,7 +85,7 @@
 | matrix(x, y) | Function | Generate a matrix parameter with respective hyperparameter dimensions. |
 | vector(x) | Function | Generate a vector with respective hyperparameter size. |
 
-Prefer using hyperparameters for matrix and vector creation, as these transfer their names to respective
+Prefer using hyperparameters (set via `.config`) for matrix and vector creation, as these transfer their names to respective
 dimensions for error checking. For `dropout,matrix,vector` you can also use the short names `drop,mat,vec`.
 
 ## Save and load architectures
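The `neural_graph.jggn` file at the top of this diff is such a saved architecture. As a round-trip sketch, assuming `ModelBuilder` exposes `save`/`load` persistence methods (an assumption based on the `.jggn` artifact; verify the exact signatures against the current javadoc):

```java
import java.nio.file.Paths;

import mklab.JGNN.adhoc.ModelBuilder;

// Persist the architecture, its configuration hyperparameters, and parameter
// values, then restore them without re-declaring layers (assumed API).
modelBuilder.save(Paths.get("JGNN/neural_graph.jggn"));
ModelBuilder restored = ModelBuilder.load(Paths.get("JGNN/neural_graph.jggn"));
```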