Skip to content

Commit

Permalink
refactor namespaces, simplify functions in CUDARuntime
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasstadler authored Mar 9, 2020
1 parent e485975 commit 8cea020
Show file tree
Hide file tree
Showing 21 changed files with 447 additions and 492 deletions.
35 changes: 24 additions & 11 deletions projects/com.nvidia.grcuda/src/com/nvidia/grcuda/GrCUDAContext.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -33,11 +33,11 @@

import org.graalvm.options.OptionKey;

import com.nvidia.grcuda.cuml.CUMLRegistry;
import com.nvidia.grcuda.functions.BindFunction;
import com.nvidia.grcuda.functions.BindKernelFunction;
import com.nvidia.grcuda.functions.BuildKernelFunction;
import com.nvidia.grcuda.functions.DeviceArrayFunction;
import com.nvidia.grcuda.functions.FunctionTable;
import com.nvidia.grcuda.functions.GetDeviceFunction;
import com.nvidia.grcuda.functions.GetDevicesFunction;
import com.nvidia.grcuda.gpu.CUDARuntime;
Expand All @@ -50,21 +50,34 @@
*/
public final class GrCUDAContext {

private static final String ROOT_NAMESPACE = "CU";

private final Env env;
private final CUDARuntime cudaRuntime;
private final FunctionTable functionTable = new FunctionTable().registerFunction(new BindFunction());
private final Namespace rootNamespace;
private final ArrayList<Runnable> disposables = new ArrayList<>();
private AtomicInteger moduleId = new AtomicInteger(0);
private boolean cudaInitialized = false;
private volatile boolean cudaInitialized = false;

public GrCUDAContext(Env env) {
this.env = env;
this.cudaRuntime = new CUDARuntime(this, env);
functionTable.registerFunction(new DeviceArrayFunction(cudaRuntime));
functionTable.registerFunction(new BindKernelFunction(cudaRuntime));
functionTable.registerFunction(new BuildKernelFunction(cudaRuntime));
functionTable.registerFunction(new GetDevicesFunction(cudaRuntime));
functionTable.registerFunction(new GetDeviceFunction(cudaRuntime));

Namespace namespace = new Namespace(ROOT_NAMESPACE);
namespace.addNamespace(namespace);
namespace.addFunction(new BindFunction());
namespace.addFunction(new DeviceArrayFunction(cudaRuntime));
namespace.addFunction(new BindKernelFunction(cudaRuntime));
namespace.addFunction(new BuildKernelFunction(cudaRuntime));
namespace.addFunction(new GetDevicesFunction(cudaRuntime));
namespace.addFunction(new GetDeviceFunction(cudaRuntime));
cudaRuntime.registerCUDAFunctions(namespace);
if (this.getOption(GrCUDAOptions.CuMLEnabled)) {
Namespace ml = new Namespace(CUMLRegistry.NAMESPACE);
namespace.addNamespace(ml);
new CUMLRegistry(this).registerCUMLFunctions(ml);
}
this.rootNamespace = namespace;
}

public Env getEnv() {
Expand All @@ -75,8 +88,8 @@ public CUDARuntime getCUDARuntime() {
return cudaRuntime;
}

public FunctionTable getFunctionTable() {
return functionTable;
public Namespace getRootNamespace() {
return rootNamespace;
}

public void addDisposable(Runnable disposable) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

import org.graalvm.options.OptionDescriptors;

import com.nvidia.grcuda.cuml.CUMLRegistry;
import com.nvidia.grcuda.nodes.ExpressionNode;
import com.nvidia.grcuda.nodes.GrCUDARootNode;
import com.nvidia.grcuda.parser.ParserAntlr;
Expand All @@ -49,12 +48,10 @@ public final class GrCUDALanguage extends TruffleLanguage<GrCUDAContext> {

@Override
protected GrCUDAContext createContext(Env env) {
GrCUDAContext context = new GrCUDAContext(env);
context.getCUDARuntime().registerCUDAFunctions(context.getFunctionTable());
if (context.getOption(GrCUDAOptions.CuMLEnabled)) {
new CUMLRegistry(context).registerCUMLFunctions(context.getFunctionTable());
if (!env.isNativeAccessAllowed()) {
throw new GrCUDAException("cannot create CUDA context without native access");
}
return context;
return new GrCUDAContext(env);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,4 @@ void writeArrayElement(long index, Object value,
throw new IllegalStateException("tried to write non-last dimension in MultiDimDeviceArrayView");
}
}

@ExportMessage
@SuppressWarnings("static-method")
boolean hasMembers() {
return false;
}

@ExportMessage
@SuppressWarnings("static-method")
Object getMembers(@SuppressWarnings("unused") boolean includeInternal) {
return null;
}
}
145 changes: 145 additions & 0 deletions projects/com.nvidia.grcuda/src/com/nvidia/grcuda/Namespace.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of NVIDIA CORPORATION nor the names of its
* contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.nvidia.grcuda;

import java.util.Optional;
import java.util.TreeMap;

import com.nvidia.grcuda.DeviceArray.MemberSet;
import com.nvidia.grcuda.functions.Function;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.interop.ArityException;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.TruffleObject;
import com.oracle.truffle.api.interop.UnknownIdentifierException;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.interop.UnsupportedTypeException;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.library.ExportLibrary;
import com.oracle.truffle.api.library.ExportMessage;

/**
* A namespace that exposes a simple interface via {@link InteropLibrary}. It is immutable from the
* point of view of the {@link InteropLibrary interop} API.
*/
@ExportLibrary(InteropLibrary.class)
public final class Namespace implements TruffleObject {

private final TreeMap<String, Object> map = new TreeMap<>();

private final String name;

public Namespace(String name) {
this.name = name;
}

@Override
public String toString() {
return name == null ? "<root>" : name;
}

private void addInternal(String newName, Object newElement) {
if (newName == null || newName.isEmpty()) {
throw new GrCUDAInternalException("cannot add elmenelement with name '" + newName + "' in namespace '" + name + "'");
}
if (map.containsKey(newName)) {
throw new GrCUDAInternalException("'" + newName + "' already exists in namespace '" + name + "'");
}
map.put(newName, newElement);
}

public void addFunction(Function function) {
addInternal(function.getName(), function);
}

public void addNamespace(Namespace namespace) {
addInternal(namespace.name, namespace);
}

@TruffleBoundary
public Optional<Object> lookup(String... path) {
if (path.length == 0) {
return Optional.empty();
}
return lookup(0, path);
}

private Optional<Object> lookup(int pos, String[] path) {
Object entry = map.get(path[pos]);
if (entry == null) {
return Optional.empty();
}
if (pos + 1 == path.length) {
return Optional.of(entry);
} else {
return entry instanceof Namespace ? ((Namespace) entry).lookup(pos + 1, path) : Optional.empty();
}
}

@SuppressWarnings("static-method")
@ExportMessage
public boolean hasMembers() {
return true;
}

@ExportMessage
@TruffleBoundary
public Object getMembers(@SuppressWarnings("unused") boolean includeInternal) {
return new MemberSet(map.keySet().toArray(new String[0]));
}

@ExportMessage
@TruffleBoundary
public boolean isMemberReadable(String member) {
return map.containsKey(member);
}

@ExportMessage
@TruffleBoundary
public Object readMember(String member) throws UnknownIdentifierException {
Object entry = map.get(member);
if (entry == null) {
throw UnknownIdentifierException.create(member);
}
return entry;
}

@ExportMessage
@TruffleBoundary
public boolean isMemberInvocable(String member) {
return map.get(member) instanceof Function;
}

@ExportMessage
public Object invokeMember(String member, Object[] arguments,
@CachedLibrary(limit = "2") InteropLibrary callLibrary)
throws UnsupportedMessageException, ArityException, UnknownIdentifierException, UnsupportedTypeException {
return callLibrary.execute(readMember(member), arguments);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
import com.nvidia.grcuda.GrCUDAException;
import com.nvidia.grcuda.GrCUDAInternalException;
import com.nvidia.grcuda.GrCUDAOptions;
import com.nvidia.grcuda.Namespace;
import com.nvidia.grcuda.functions.ExternalFunctionFactory;
import com.nvidia.grcuda.functions.Function;
import com.nvidia.grcuda.functions.FunctionTable;
import com.nvidia.grcuda.gpu.UnsafeHelper;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
Expand Down Expand Up @@ -89,7 +89,7 @@ private void ensureInitialized() {

// create wrapper for cumlCreate: cumlError_t cumlCreate(int* handle) -> int
// cumlCreate()
cumlCreateFunction = new Function(CUMLFunctionNFI.CUML_CUMLCREATE.getFunctionFactory().getName(), NAMESPACE) {
cumlCreateFunction = new Function(CUMLFunctionNFI.CUML_CUMLCREATE.getFunctionFactory().getName()) {
@Override
@TruffleBoundary
public Object call(Object[] arguments) throws ArityException {
Expand All @@ -107,7 +107,7 @@ public Object call(Object[] arguments) throws ArityException {

// create wrapper for cumlDestroy: cumlError_t cumlDestroy(int handle) -> void
// cumlDestroy(int handle)
cumlDestroyFunction = new Function(CUMLFunctionNFI.CUML_CUMLDESTROY.getFunctionFactory().getName(), NAMESPACE) {
cumlDestroyFunction = new Function(CUMLFunctionNFI.CUML_CUMLDESTROY.getFunctionFactory().getName()) {
@Override
@TruffleBoundary
public Object call(Object[] arguments) throws ArityException, UnsupportedTypeException {
Expand Down Expand Up @@ -148,13 +148,13 @@ private void cuMLShutdown() {
}
}

public void registerCUMLFunctions(FunctionTable functionTable) {
public void registerCUMLFunctions(Namespace namespace) {
// Create function wrappers (decorators for all functions except handle con- and
// destruction)
List<CUMLFunctionNFI> hiddenFunctions = Arrays.asList(CUMLFunctionNFI.CUML_CUMLCREATE, CUMLFunctionNFI.CUML_CUMLDESTROY);
EnumSet.allOf(CUMLFunctionNFI.class).stream().filter(func -> !hiddenFunctions.contains(func)).forEach(func -> {
final ExternalFunctionFactory factory = func.getFunctionFactory();
final Function wrapperFunction = new Function(factory.getName(), NAMESPACE) {
final Function wrapperFunction = new Function(factory.getName()) {

private Function nfiFunction;

Expand Down Expand Up @@ -183,7 +183,7 @@ public Object call(Object[] arguments) {
}
}
};
functionTable.registerFunction(wrapperFunction);
namespace.addFunction(wrapperFunction);
});
}

Expand Down Expand Up @@ -215,20 +215,10 @@ private static String cumlReturnCodeToString(int returnCode) {
}

public enum CUMLFunctionNFI {
CUML_CUMLCREATE(
new ExternalFunctionFactory("cumlCreate",
NAMESPACE, "cumlCreate", "(pointer): sint32")),
CUML_CUMLDESTROY(
new ExternalFunctionFactory("cumlDestroy",
NAMESPACE, "cumlDestroy", "(sint32): sint32")),
CUML_DBSCANFITDOUBLE(
new ExternalFunctionFactory("cumlDpDbscanFit",
NAMESPACE, "cumlDpDbscanFit",
"(sint32, pointer, sint32, sint32, double, sint32, pointer, uint64, sint32): sint32")),
CUML_DBSCANFITFLOAT(
new ExternalFunctionFactory("cumlSpDbscanFit",
NAMESPACE, "cumlSpDbscanFit",
"(sint32, pointer, sint32, sint32, float, sint32, pointer, uint64, sint32): sint32"));
CUML_CUMLCREATE(new ExternalFunctionFactory("cumlCreate", "cumlCreate", "(pointer): sint32")),
CUML_CUMLDESTROY(new ExternalFunctionFactory("cumlDestroy", "cumlDestroy", "(sint32): sint32")),
CUML_DBSCANFITDOUBLE(new ExternalFunctionFactory("cumlDpDbscanFit", "cumlDpDbscanFit", "(sint32, pointer, sint32, sint32, double, sint32, pointer, uint64, sint32): sint32")),
CUML_DBSCANFITFLOAT(new ExternalFunctionFactory("cumlSpDbscanFit", "cumlSpDbscanFit", "(sint32, pointer, sint32, sint32, float, sint32, pointer, uint64, sint32): sint32"));

private final ExternalFunctionFactory factory;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
public final class BindFunction extends Function {

public BindFunction() {
super("bind", "");
super("bind");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -38,7 +38,7 @@ public final class BindKernelFunction extends Function {
private final CUDARuntime cudaRuntime;

public BindKernelFunction(CUDARuntime cudaRuntime) {
super("bindkernel", "");
super("bindkernel");
this.cudaRuntime = cudaRuntime;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -37,7 +37,7 @@ public class BuildKernelFunction extends Function {
private final CUDARuntime cudaRuntime;

public BuildKernelFunction(CUDARuntime cudaRuntime) {
super("buildkernel", "");
super("buildkernel");
this.cudaRuntime = cudaRuntime;
}

Expand Down
Loading

0 comments on commit 8cea020

Please sign in to comment.