Skip to content

Commit

Permalink
Wasm code generation performance optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
mirkosertic committed Apr 19, 2024
1 parent 179ac88 commit a43f6a9
Show file tree
Hide file tree
Showing 14 changed files with 115 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import de.mirkosertic.bytecoder.core.backend.wasm.ast.ExportableFunction;
import de.mirkosertic.bytecoder.core.backend.wasm.ast.Exporter;
import de.mirkosertic.bytecoder.core.backend.wasm.ast.Function;
import de.mirkosertic.bytecoder.core.backend.wasm.ast.FunctionIndex;
import de.mirkosertic.bytecoder.core.backend.wasm.ast.FunctionType;
import de.mirkosertic.bytecoder.core.backend.wasm.ast.FunctionsSection;
import de.mirkosertic.bytecoder.core.backend.wasm.ast.Global;
Expand Down Expand Up @@ -584,6 +585,8 @@ public WasmType apply(final Type argument) {
final OpaqueTypesAdapterMethods adapterMethods = new OpaqueTypesAdapterMethods();
final GeneratedMethodsRegistry generatedMethodsRegistry = new GeneratedMethodsRegistry();

final FunctionIndex functionIndex = module.functionIndex();

for (final ResolvedClass cl : resolvedClasses) {
// Class objects for

Expand Down Expand Up @@ -656,7 +659,8 @@ public WasmType apply(final Type argument) {
compileUnit,
objectTypeMappings,
rtTypeMappings,
ConstExpressions.call(delegateFunction, callArgs)
ConstExpressions.call(delegateFunction, callArgs),
functionIndex
));
break;
}
Expand Down Expand Up @@ -795,7 +799,8 @@ public WasmType apply(final Type argument) {
compileUnit,
objectTypeMappings,
rtTypeMappings,
ConstExpressions.call(delegateFunction, callArgs)));
ConstExpressions.call(delegateFunction, callArgs),
functionIndex));
break;
}
default: {
Expand Down Expand Up @@ -825,7 +830,7 @@ public WasmType apply(final Type argument) {
final DominatorTree dt = new DominatorTree(g);

try {
new Sequencer(g, dt, new WasmStructuredControlflowCodeGenerator(compileUnit, module, rtTypeMappings, objectTypeMappings, implFunction, toWASMType, toFunctionType, methodToIDMapper, g, resolvedClasses, vTableResolver, generatedMethodsRegistry));
new Sequencer(g, dt, new WasmStructuredControlflowCodeGenerator(compileUnit, module, rtTypeMappings, objectTypeMappings, implFunction, toWASMType, toFunctionType, methodToIDMapper, g, resolvedClasses, vTableResolver, generatedMethodsRegistry, functionIndex));
} catch (final CodeGenerationFailure e) {
throw e;
} catch (final RuntimeException e) {
Expand Down Expand Up @@ -868,7 +873,7 @@ public WasmType apply(final Type argument) {
"typeId"
)
);
initArgs.add(ConstExpressions.ref.ref(module.functionIndex().firstByLabel(WasmHelpers.generateClassName(stringClass.type) + "_vt")));
initArgs.add(ConstExpressions.ref.ref(functionIndex.findByLabel(WasmHelpers.generateClassName(stringClass.type) + "_vt")));
initArgs.add(ConstExpressions.getLocal(newStringFunction.localByLabel("str")));

final Global stringGlobal = module.getGlobals().globalsIndex().globalByLabel(WasmHelpers.generateClassName(stringClass.type) + "_cls");
Expand Down Expand Up @@ -909,7 +914,7 @@ public WasmType apply(final Type argument) {
ConstExpressions.call(stringInitFunction, new ArrayList<>()),
"factoryFor"
));
initArgs.add(ConstExpressions.ref.ref(module.functionIndex().firstByLabel(WasmHelpers.generateClassName(stringClass.type) + "_vt")));
initArgs.add(ConstExpressions.ref.ref(functionIndex.findByLabel(WasmHelpers.generateClassName(stringClass.type) + "_vt")));

final Global stringGlobal = module.getGlobals().globalsIndex().globalByLabel(WasmHelpers.generateClassName(stringClass.type) + "_cls");

Expand Down Expand Up @@ -994,7 +999,8 @@ public WasmType apply(final Type argument) {
compileUnit,
objectTypeMappings,
rtTypeMappings,
ConstExpressions.getLocal(callback.localByLabel("event"))
ConstExpressions.getLocal(callback.localByLabel("event")),
functionIndex
));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import de.mirkosertic.bytecoder.core.backend.wasm.ast.Container;
import de.mirkosertic.bytecoder.core.backend.wasm.ast.ExportableFunction;
import de.mirkosertic.bytecoder.core.backend.wasm.ast.Expressions;
import de.mirkosertic.bytecoder.core.backend.wasm.ast.FunctionIndex;
import de.mirkosertic.bytecoder.core.backend.wasm.ast.FunctionType;
import de.mirkosertic.bytecoder.core.backend.wasm.ast.Global;
import de.mirkosertic.bytecoder.core.backend.wasm.ast.HostType;
Expand Down Expand Up @@ -240,6 +241,8 @@ public NestingLevelTry(final NestingLevel<?> parent, final Expressions activeFlo

private final GeneratedMethodsRegistry generatedMethodsRegistry;

private final FunctionIndex functionIndex;

public WasmStructuredControlflowCodeGenerator(final CompileUnit compileUnit, final Module module,
final Map<ResolvedClass, StructType> rtMappings,
final Map<ResolvedClass, StructType> objectTypeMappings,
Expand All @@ -250,7 +253,8 @@ public WasmStructuredControlflowCodeGenerator(final CompileUnit compileUnit, fin
final Graph graph,
final List<ResolvedClass> resolvedClasses,
final VTableResolver vTableResolver,
final GeneratedMethodsRegistry generatedMethodsRegistry) {
final GeneratedMethodsRegistry generatedMethodsRegistry,
final FunctionIndex functionIndex) {
this.compileUnit = compileUnit;
this.module = module;
this.exportableFunction = exportableFunction;
Expand All @@ -265,6 +269,7 @@ public WasmStructuredControlflowCodeGenerator(final CompileUnit compileUnit, fin
this.resolvedClasses = resolvedClasses;
this.vTableResolver = vTableResolver;
this.generatedMethodsRegistry = generatedMethodsRegistry;
this.functionIndex = functionIndex;
}

@Override
Expand Down Expand Up @@ -463,7 +468,8 @@ public static WasmValue createNewInstanceOf(final Type instanceType,
final CompileUnit compileUnit,
final Map<ResolvedClass, StructType> objectTypeMappings,
final Map<ResolvedClass, StructType> rtMappings,
final WasmValue externRef) {
final WasmValue externRef,
final FunctionIndex functionIndex) {
final ResolvedClass cl = compileUnit.findClass(instanceType);
if (cl == null) {
throw new IllegalArgumentException("Cannot find resolved class for " + instanceType);
Expand All @@ -485,7 +491,7 @@ public static WasmValue createNewInstanceOf(final Type instanceType,
)
);

initArgs.add(ConstExpressions.ref.ref(module.functionIndex().firstByLabel(WasmHelpers.generateClassName(cl.type) + "_vt")));
initArgs.add(ConstExpressions.ref.ref(functionIndex.findByLabel(WasmHelpers.generateClassName(cl.type) + "_vt")));

initArgs.add(externRef);

Expand Down Expand Up @@ -530,7 +536,7 @@ public static WasmValue createNewInstanceOf(final Type instanceType,

private WasmValue toWasmValue(final New value) {
return createNewInstanceOf(value.type,
module, compileUnit, objectTypeMappings, rtMappings, ConstExpressions.ref.externNullRef());
module, compileUnit, objectTypeMappings, rtMappings, ConstExpressions.ref.externNullRef(), functionIndex);
}

private WasmValue toWasmValue(final ReadInstanceField value) {
Expand Down Expand Up @@ -1102,7 +1108,7 @@ private WasmValue generateInvokeDynamicLambdaMetaFactoryInvocation(final InvokeD

lambdaMethod.flow.setLocal(objInstance,
createNewInstanceOf(implementationMethod.owner.type,
module, compileUnit, objectTypeMappings, rtMappings, ConstExpressions.ref.externNullRef()));
module, compileUnit, objectTypeMappings, rtMappings, ConstExpressions.ref.externNullRef(), functionIndex));

arguments.add(0, ConstExpressions.getLocal(objInstance));

Expand Down Expand Up @@ -1249,7 +1255,7 @@ public void generateCode(final PrintWriter pw, final int index) {
ConstExpressions.getGlobal(stringGlobal),
"factoryFor"
));
initArgs.add(ConstExpressions.ref.ref(module.functionIndex().firstByLabel(WasmHelpers.generateClassName(stringClass.type) + "_vt")));
initArgs.add(ConstExpressions.ref.ref(functionIndex.findByLabel(WasmHelpers.generateClassName(stringClass.type) + "_vt")));

initArgs.add(ConstExpressions.call(concatFunction, arguments));

Expand Down Expand Up @@ -1386,7 +1392,7 @@ public void generateCode(final PrintWriter pw, final int index) {
ConstExpressions.getGlobal(stringGlobal),
"factoryFor"
));
initArgs.add(ConstExpressions.ref.ref(module.functionIndex().firstByLabel(WasmHelpers.generateClassName(stringClass.type) + "_vt")));
initArgs.add(ConstExpressions.ref.ref(functionIndex.findByLabel(WasmHelpers.generateClassName(stringClass.type) + "_vt")));

initArgs.add(ConstExpressions.call(concatFunction, arguments));

Expand Down Expand Up @@ -1792,7 +1798,7 @@ private WasmValue toWasmValue(final NewArray value) {
final Global global = module.getGlobals().globalsIndex().globalByLabel(arrayClsName + "_cls");

initArguments.add(ConstExpressions.i32.c(resolvedClasses.indexOf(arrayCls)));
initArguments.add(ConstExpressions.ref.ref(module.functionIndex().firstByLabel(WasmHelpers.generateClassName(arrayClass) + "_vt")));
initArguments.add(ConstExpressions.ref.ref(functionIndex.findByLabel(WasmHelpers.generateClassName(arrayClass) + "_vt")));
initArguments.add(ConstExpressions.ref.externNullRef());
initArguments.add(ConstExpressions.struct.get(
rtMappings.get(arrayCls),
Expand Down Expand Up @@ -2058,7 +2064,8 @@ private WasmValue toWasmValue(final Cast value) {
objectTypeMappings.get(compileUnit.findClass(Type.getType(Object.class))),
toWasmValue((Value) value.incomingDataFlows[0]),
"nativeObject"
)
),
functionIndex
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ public Local newLocal(final String label, final WasmType type) {
}

@Override
public void writeTo(final TextWriter textWriter, final Module aModule) throws IOException {
public void writeTo(final TextWriter textWriter, final Module aModule, final WasmValue.ExportContext exportContext) throws IOException {
textWriter.opening();
textWriter.write("func");
textWriter.space();
Expand Down Expand Up @@ -145,7 +145,7 @@ public void writeTo(final TextWriter textWriter, final Module aModule) throws IO
local.writeTo(textWriter);
textWriter.newLine();
}
final DefaultExportContext context = new DefaultExportContext(this, getModule().functionIndex());
final DefaultExportContext context = new DefaultExportContext(this, exportContext.functionIndex());
for (final WasmExpression expression : getChildren()) {
expression.writeTo(textWriter, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ public void writeTo(final TextWriter textWriter) {
}
}

public void writeTo(final BinaryWriter binaryWriter, final List<Memory> memoryIndex) throws IOException {
final FunctionIndex functionIndex = getModule().functionIndex();
public void writeTo(final BinaryWriter binaryWriter, final List<Memory> memoryIndex, final WasmValue.ExportContext exportContext) throws IOException {
final TagIndex tagIndex = getModule().tagIndex();
try (final BinaryWriter.SectionWriter exportWriter = binaryWriter.exportsSection()) {
exportWriter.writeUnsignedLeb128(exports.size());
Expand All @@ -56,7 +55,7 @@ public void writeTo(final BinaryWriter binaryWriter, final List<Memory> memoryIn
final Exportable value = entry.getValue();
if (value instanceof ExportableFunction) {
exportWriter.writeByte(ExternalKind.EXTERNAL_KIND_FUNCTION);
exportWriter.writeUnsignedLeb128(functionIndex.indexOf((Function) value));
exportWriter.writeUnsignedLeb128(exportContext.functionIndex().indexOf((Function) value));
} else if (value instanceof Memory) {
exportWriter.writeByte(ExternalKind.EXTERNAL_KIND_MEMORY);
exportWriter.writeUnsignedLeb128(memoryIndex.indexOf(value));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public class Function extends Container implements Importable, Callable {
}

@Override
public void writeTo(final TextWriter textWriter, final Module aModule) throws IOException {
public void writeTo(final TextWriter textWriter, final Module aModule, final WasmValue.ExportContext exportContext) throws IOException {
textWriter.opening();
textWriter.write("func");
textWriter.space();
Expand Down Expand Up @@ -96,6 +96,6 @@ public WasmType resolveResultType(final WasmValue.ExportContext context) {

@Override
public int resolveIndex(final WasmValue.ExportContext context) {
return module.functionIndex().indexOf(this);
return context.functionIndex().indexOf(this);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@
package de.mirkosertic.bytecoder.core.backend.wasm.ast;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class FunctionIndex {

private final List<Function> functions;
private final Map<String, Function> labelToFunction;

FunctionIndex() {
functions = new ArrayList<>();
labelToFunction = new HashMap<>();
}

public int size() {
Expand All @@ -37,6 +41,7 @@ public Function get(final int aIndex) {

public void add(final Function function) {
functions.add(function);
labelToFunction.put(function.getLabel().toLowerCase(), function);
}

public int indexOf(final Function value) {
Expand All @@ -47,12 +52,12 @@ public List<ExportableFunction> exportableFunctions() {
return functions.stream().filter(t -> t instanceof ExportableFunction).map(t -> (ExportableFunction) t).collect(Collectors.toList());
}

public <T extends Function> T firstByLabel(final String label) {
for (final Function function : functions) {
if (label.equalsIgnoreCase(function.getLabel())) {
return (T) function;
}
public <T extends Function> T findByLabel(final String label) {
final String key = label.toLowerCase();
final T result = (T) labelToFunction.get(key);
if (result == null) {
throw new IllegalArgumentException("No such method : " + label);
}
throw new IllegalArgumentException("No such method : " + label);
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ public ExportableFunction newFunction(final String label) {
return function;
}

public void writeTo(final TextWriter textWriter) throws IOException {
public void writeTo(final TextWriter textWriter, final WasmValue.ExportContext exportContext) throws IOException {
for (final Function function : functions) {
function.writeTo(textWriter, getModule());
function.writeTo(textWriter, getModule(), exportContext);
textWriter.newLine();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
import java.io.IOException;

public interface Importable {
void writeTo(TextWriter textWriter, Module aModule) throws IOException;
void writeTo(TextWriter textWriter, Module aModule, final WasmValue.ExportContext exportContext) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public Function importFunction(final ImportReference importReference, final Stri
return function;
}

public void writeTo(final TextWriter textWriter) throws IOException {
public void writeTo(final TextWriter textWriter, final WasmValue.ExportContext exportContext) throws IOException {
for (final ImportEntry entry : imports) {

final ImportReference ref = entry.getReference();
Expand All @@ -81,7 +81,7 @@ public void writeTo(final TextWriter textWriter) throws IOException {
textWriter.space();
textWriter.writeText(ref.getObjectName());
textWriter.space();
entry.getImportable().writeTo(textWriter, getModule());
entry.getImportable().writeTo(textWriter, getModule(), exportContext);
textWriter.closing();
textWriter.newLine();
}
Expand Down
Loading

0 comments on commit a43f6a9

Please sign in to comment.