Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #448: Transform global variables in stats package #449

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ atlassian-ide-plugin.xml
*.o
*.gimple
.Rproj.user
.gradle
3 changes: 3 additions & 0 deletions packages/stats/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@
<goal>gnur-sources-compile</goal>
</goals>
<phase>compile</phase>
<configuration>
<transformGlobalVariables>true</transformGlobalVariables>
</configuration>
</execution>
<execution>
<id>build</id>
Expand Down
4 changes: 3 additions & 1 deletion packages/stats/src/main/c/apply_optim.c
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,9 @@ void cgmin(int n, double *Bvec, double *X, double *Fmin,
}

/* include setulb() */
#include "lbfgsb.c"
void setulb(int n, int m, double *x, double *l, double *u, int *nbd,
double *f, double *g, double factr, double *pgtol,
double *wa, int * iwa, char *task, int iprint, int *isave);

void lbfgsb(int n, int m, double *x, double *l, double *u, int *nbd,
double *Fmin, optimfn fminfn, optimgr fmingr, int *fail,
Expand Down
1 change: 0 additions & 1 deletion packages/stats/src/main/c/lbfgsb.c
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ static void prn3lb(int n, double *x, double *f, char *task, int iprint,


/* ================ L-BFGS-B (version 2.3) ========================== */
static
void setulb(int n, int m, double *x, double *l, double *u, int *nbd,
double *f, double *g, double factr, double *pgtol,
double *wa, int * iwa, char *task, int iprint, int *isave)
Expand Down
69 changes: 69 additions & 0 deletions packages/stats/src/main/c/lbfgsb.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@

static void active(int, double *, double *, int *, double *, int *,
int, int *, int *, int *);
static void bmv(int, double *, double *, int *, double *, double *, int *);
static void cauchy(int, double *, double *,
double *, int *, double *, int *, int *,
double *, double *, double *, int, double *,
double *, double *, double *, double *, int * ,
int *, double *, double *, double *, double *,
int *, int, double *, int *, double *);
static void cmprlb(int, int, double *,
double *, double *, double *, double *,
double *, double *, double *, double *, int *,
double *, int *, int *, int *, int *,
int *);
static void dcsrch(double *, double *, double *,
double, double, double,
double, double, char *);
static void dcstep(double *, double *,
double *, double *, double *, double *,
double *, double *, double *, int *, double *,
double *);
static void errclb(int, int, double,
double *, double *, int *, char *, int *, int *);
static void formk(int, int *, int *, int *, int *, int *, int *,
int *, double *, double *, int, double *,
double *, double *, double *, int *, int *, int *);
static void formt(int, double *, double *,
double *, int *, double *, int *);
static void freev(int, int *, int *,
int *, int *, int *, int *, int *, int *,
int *, int, int *);
static void hpsolb(int, double *, int *, int);
static void lnsrlb(int, double *, double *,
int *, double *, double *, double *, double *,
double *, double *, double *, double *,
double *, double *, double *, double *,
double *, double *, double *, int *, int *,
int *, int *, int *, char *, int *, int *,
char *);
static void mainlb(int, int, double *,
double *, double *, int *, double *, double *,
double, double *, double *, double *,
double *, double *, double *, double *,
double *, double *, double *, double *,
double *, double *, int *, int *, int *, char *,
int, char *, int *);
static void matupd(int, int, double *, double *, double *,
double *, double *, double *, int *, int *,
int *, int *, double *, double *, double *,
double *, double *);
static void projgr(int, double *, double *,
int *, double *, double *, double *);
static void subsm(int, int, int *, int *, double *, double *,
int *, double *, double *, double *, double *,
double *, int *, int *, int *, double *,
double *, int, int *);

static void prn1lb(int n, int m, double *l, double *u, double *x,
int iprint, double epsmch);
static void prn2lb(int n, double *x, double *f, double *g, int iprint,
int iter, int nfgv, int nact, double sbgnrm,
int nint, char *word, int iword, int iback,
double stp, double xstep);
static void prn3lb(int n, double *x, double *f, char *task, int iprint,
int info, int iter, int nfgv, int nintol, int nskip,
int nact, double sbgnrm, int nint,
char *word, int iback, double stp, double xstep,
int k);
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Renjin : JVM-based interpreter for the R language for the statistical analysis
* Copyright © 2010-${$file.lastModified.year} BeDataDriven Groep B.V. and contributors
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, a copy is available at
* https://www.gnu.org/licenses/gpl-2.0.txt
*
*/

package org.renjin.script;

import org.junit.Test;
import org.renjin.sexp.DoubleVector;
import org.renjin.sexp.ListVector;

import javax.script.ScriptEngine;

public class OptimMultiThreadTest {
private static final ThreadLocal<ScriptEngine> engines = new ThreadLocal<>();
private static final int RUNS = 20;
private static final int THREADS_NUMBER = 2;


@Test
public void testMultithread() throws InterruptedException {
Thread[] threads = new Thread[THREADS_NUMBER];
for (int i = 0; i < THREADS_NUMBER; i++) {
threads[i] = new Thread(this::task);
threads[i].start();
}
for (int i = 0; i < THREADS_NUMBER; i++) {
threads[i].join();

}
}

private void task() {
ScriptEngine engine = getEngine();
try {
// Rosenbrock Banana function
engine.eval("fr <- function(x) { " +
" x1 <- x[1]\n" +
" x2 <- x[2]\n" +
" 100 * (x2 - x1 * x1) ^ 2 + (1 - x1) ^ 2\n" +
"}");
} catch (Exception e) {
System.err.println(e.getMessage());
}
for (int i = 0; i < RUNS; i++) {
try {
ListVector result = (ListVector) engine.eval("optim(c(-1.2,1), fr, method = \"L-BFGS\")");
System.out.println(result);

DoubleVector parameters = (DoubleVector) result.get("par");
double p0 = parameters.getElementAsDouble(0);
double p1 = parameters.getElementAsDouble(1);
if(Math.abs(p0 - 0.9998000) > 0.00001) {
throw new AssertionError("Incorrect result: p0 = " + String.format("%f", p0));
}
if(Math.abs(p1 - 0.9996001) > 0.00001) {
throw new AssertionError("Incorrect result: p1 = " + String.format("%f", p1));
}
} catch (Exception e) {
System.err.println(e.getMessage());
}
}
}

private ScriptEngine getEngine() {
ScriptEngine engine = engines.get();
if (engine == null) {
RenjinScriptEngineFactory factory = new RenjinScriptEngineFactory();
engine = factory.getScriptEngine();
engines.set(engine);
}
return engine;
}

}
23 changes: 23 additions & 0 deletions tests/src/test/R/test.format.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#
# Renjin : JVM-based interpreter for the R language for the statistical analysis
# Copyright © 2010-2019 BeDataDriven Groep B.V. and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, a copy is available at
# https://www.gnu.org/licenses/gpl-2.0.txt
#
library(hamcrest)

test.format.double <- function() {
assertThat(format(1/3), equalTo("0.3333333"))
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public class FunctionGenerator implements InvocationStrategy {
private TypeOracle typeOracle;
private FunctionOracle functionOracle;
private ExprFactory exprFactory;
private List<GlobalVarTransformer> globalVarTransformers;
private LocalStaticVarAllocator staticVarAllocator;
private LocalVariableTable localSymbolTable;
private LocalVariableTable localStaticSymbolTable;
Expand All @@ -85,7 +86,9 @@ public class FunctionGenerator implements InvocationStrategy {
private boolean compilationFailed = false;

public FunctionGenerator(String className, GimpleFunction function, TypeOracle typeOracle,
GlobalVarAllocator globalVarAllocator, UnitSymbolTable symbolTable) {
GlobalVarAllocator globalVarAllocator,
List<GlobalVarTransformer> globalVarTransformers,
UnitSymbolTable symbolTable) {
this.className = className;
this.function = function;
this.typeOracle = typeOracle;
Expand All @@ -103,6 +106,7 @@ public FunctionGenerator(String className, GimpleFunction function, TypeOracle t
this.staticVarAllocator = new LocalStaticVarAllocator("$" + function.getSafeMangledName() + "$", globalVarAllocator);
this.localSymbolTable = new LocalVariableTable(symbolTable);
this.localStaticSymbolTable = new LocalVariableTable(symbolTable);
this.globalVarTransformers = globalVarTransformers;

}

Expand Down Expand Up @@ -375,15 +379,14 @@ private void scheduleLocalVariables() {
}

try {
GExpr generator = functionOracle.variable(varDecl,
varDecl.isStatic() ?
staticVarAllocator :
mv.getLocalVarAllocator());

localSymbolTable.addVariable(varDecl.getId(), generator);

if(varDecl.isStatic()) {
GExpr generator = localStaticGenerator(varDecl);
localSymbolTable.addVariable(varDecl.getId(), generator);
localStaticSymbolTable.addVariable(varDecl.getId(), generator);

} else {
localSymbolTable.addVariable(varDecl.getId(),
functionOracle.variable(varDecl, mv.getLocalVarAllocator()));
}

} catch (Exception e) {
Expand All @@ -392,6 +395,15 @@ private void scheduleLocalVariables() {
}
}

private GExpr localStaticGenerator(GimpleVarDecl varDecl) {
for (GlobalVarTransformer globalVarTransformer : globalVarTransformers) {
if(globalVarTransformer.acceptLocalStaticVar(varDecl)) {
return globalVarTransformer.generator(typeOracle, function, varDecl);
}
}
return functionOracle.variable(varDecl, staticVarAllocator);
}

private void emitBasicBlock(GimpleBasicBlock basicBlock) {
mv.visitLabel(labels.of(basicBlock));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,17 @@
import org.renjin.gcc.codegen.expr.GExpr;
import org.renjin.gcc.codegen.type.TypeOracle;
import org.renjin.gcc.gimple.GimpleCompilationUnit;
import org.renjin.gcc.gimple.GimpleFunction;
import org.renjin.gcc.gimple.GimpleVarDecl;

public interface GlobalVarTransformer {
boolean accept(GimpleVarDecl decl);

boolean acceptGlobalVar(GimpleVarDecl decl);

boolean acceptLocalStaticVar(GimpleVarDecl decl);

GExpr generator(TypeOracle typeOracle, GimpleCompilationUnit unit, GimpleVarDecl decl);

GExpr generator(TypeOracle typeOracle, GimpleFunction function, GimpleVarDecl decl);

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.renjin.gcc.codegen.expr.GExpr;
import org.renjin.gcc.codegen.type.TypeOracle;
import org.renjin.gcc.gimple.GimpleCompilationUnit;
import org.renjin.gcc.gimple.GimpleFunction;
import org.renjin.gcc.gimple.GimpleVarDecl;

import java.util.Map;
Expand All @@ -37,12 +38,22 @@ public ProvidedVarTransformer(TypeOracle typeOracle, Map<String, ProvidedGlobalV
}

@Override
public boolean accept(GimpleVarDecl decl) {
public boolean acceptGlobalVar(GimpleVarDecl decl) {
return decl.isPublic() && providedVariables.containsKey(decl.getName());
}

@Override
public boolean acceptLocalStaticVar(GimpleVarDecl decl) {
return false;
}

@Override
public GExpr generator(TypeOracle typeOracle, GimpleCompilationUnit unit, GimpleVarDecl decl) {
return providedVariables.get(decl.getName()).createExpr(decl, this.typeOracle);
}

@Override
public GExpr generator(TypeOracle typeOracle, GimpleFunction function, GimpleVarDecl decl) {
throw new UnsupportedOperationException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public UnitClassGenerator(TypeOracle typeOracle,
if (!isExcluded(function)) {
try {
symbolTable.addFunction(function,
new FunctionGenerator(className, function, typeOracle, globalVarAllocator, symbolTable));
new FunctionGenerator(className, function, typeOracle, globalVarAllocator, globalVarTransformers, symbolTable));
} catch (Exception e) {
throw new InternalCompilerException(String.format("Exception creating %s for %s in %s: %s",
FunctionGenerator.class.getSimpleName(),
Expand Down Expand Up @@ -135,7 +135,7 @@ private GExpr generatorForGlobalVar(List<GlobalVarTransformer> globalVarTransfor
GimpleVarDecl decl) {

for (GlobalVarTransformer transformer : globalVarTransformers) {
if(transformer.accept(decl)) {
if(transformer.acceptGlobalVar(decl)) {
return transformer.generator(typeOracle, this.unit, decl);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
*/
public class PointerPtr extends AbstractPtr {

private static final int BYTES = 4;
public static final int BYTES = 4;

public static final PointerPtr NULL = new PointerPtr(null, 0);

Expand Down
Loading