Skip to content

Commit

Permalink
performance tweaks for multidimensional arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasstadler committed Sep 13, 2019
1 parent b9d36e2 commit 02193bb
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,15 @@ public class MultiDimDeviceArray implements TruffleObject {
/** Number of elements in each dimension. */
private final long[] elementsPerDimension;

/** Stride in each dimension. */
private final long[] stridePerDimension;

/** Total number of elements stored in the array. */
private final long totalElementCount;

/** true if data is stored in column-major format (Fortran), false row-major (C). */
private boolean columnMajor;

private final long stride;

/** Mutable view onto the underlying memory buffer. */
private final LittleEndianNativeArrayView nativeView;

Expand All @@ -87,12 +88,29 @@ public MultiDimDeviceArray(CUDARuntime runtime, ElementType elementType, long[]
this.elementType = elementType;
this.elementsPerDimension = new long[dimensions.length];
System.arraycopy(dimensions, 0, this.elementsPerDimension, 0, dimensions.length);
this.stridePerDimension = computeStride(dimensions, columnMajor);
this.totalElementCount = prod;
this.columnMajor = useColumnMajor;
this.stride = computeStrideInDim(0);
this.nativeView = runtime.cudaMallocManaged(getSizeBytes());
}

private static long[] computeStride(long[] dimensions, boolean columnMajor) {
long prod = 1;
long[] stride = new long[dimensions.length];
if (columnMajor) {
for (int i = 0; i < dimensions.length; i++) {
stride[i] = prod;
prod *= dimensions[i];
}
} else {
for (int i = dimensions.length - 1; i >= 0; i--) {
stride[i] = prod;
prod *= dimensions[i];
}
}
return stride;
}

public final int getNumberDimensions() {
return elementsPerDimension.length;
}
Expand All @@ -105,12 +123,20 @@ public final long[] getShape() {

public final long getElementsInDimension(int dimension) {
if (dimension < 0 || dimension >= elementsPerDimension.length) {
throw new IllegalArgumentException("invalid dimension index " + dimension +
", valid [0, " + elementsPerDimension.length + ']');
CompilerDirectives.transferToInterpreter();
throw new IllegalArgumentException("invalid dimension index " + dimension + ", valid [0, " + elementsPerDimension.length + ']');
}
return elementsPerDimension[dimension];
}

public long getStrideInDimension(int dimension) {
if (dimension < 0 || dimension >= stridePerDimension.length) {
CompilerDirectives.transferToInterpreter();
throw new IllegalArgumentException("invalid dimension index " + dimension + ", valid [0, " + stridePerDimension.length + ']');
}
return stridePerDimension[dimension];
}

final boolean isIndexValidInDimension(long index, int dimension) {
long numElementsInDim = getElementsInDimension(dimension);
return (index > 0) && (index < numElementsInDim);
Expand Down Expand Up @@ -177,52 +203,14 @@ boolean isArrayElementReadable(long index) {
return index >= 0 && index < elementsPerDimension[0];
}

@ExportMessage
@SuppressWarnings("static-method")
boolean isArrayElementModifiable(@SuppressWarnings("unused") long index) {
return false;
}

@ExportMessage
@SuppressWarnings("static-method")
boolean isArrayElementInsertable(@SuppressWarnings("unused") long index) {
return false;
}

final long computeStrideInDim(int dim) {
long prod = 1;
if (columnMajor) {
for (int i = 0; i < dim; i++) {
prod *= elementsPerDimension[i];
}
} else {
for (int i = dim + 1; i < getNumberDimensions(); i++) {
prod *= elementsPerDimension[i];
}
}
return prod;
}

@ExportMessage
Object readArrayElement(long index) throws InvalidArrayIndexException {
// System.out.println("MultiDimDeviceArray::readArrayElement(" + index + ')');
if ((index < 0) || (index >= elementsPerDimension[0])) {
CompilerDirectives.transferToInterpreter();
throw InvalidArrayIndexException.create(index);
}
long offset = index * stride;
long newStride;
if (columnMajor) {
newStride = elementsPerDimension[0];
} else {
newStride = stride / elementsPerDimension[1];
}
return new MultiDimDeviceArrayView(this, 1, offset, newStride);
}

@ExportMessage
void writeArrayElement(@SuppressWarnings("unused") long index, @SuppressWarnings("unused") Object value) {
throw new IllegalStateException("attempting to write MultiDimensionArray directly");
long offset = index * stridePerDimension[0];
return new MultiDimDeviceArrayView(this, 1, offset, stridePerDimension[1]);
}

@ExportMessage
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -53,8 +54,6 @@ public final class MultiDimDeviceArrayView implements TruffleObject {
this.thisDimension = dim;
this.offset = offset;
this.stride = stride;
// System.out.printf("MultiDimDeviceArrayView(dim=%d, offset=%d, stride=%d)\n", dim, offset,
// stride);
}

public int getDimension() {
Expand Down Expand Up @@ -110,14 +109,12 @@ boolean isArrayElementInsertable(@SuppressWarnings("unused") long index) {
@ExportMessage
Object readArrayElement(long index,
@Shared("elementType") @Cached("createIdentityProfile()") ValueProfile elementTypeProfile) throws InvalidArrayIndexException {
// System.out.println("MultiDimDeviceArrayView::readArrayElement(" + index + ')');
if ((index < 0) || (index >= mdDeviceArray.getElementsInDimension(thisDimension))) {
CompilerDirectives.transferToInterpreter();
throw InvalidArrayIndexException.create(index);
}
if ((thisDimension + 1) == mdDeviceArray.getNumberDimensions()) {
long flatIndex = offset + index * stride;
// System.out.println("R access " + flatIndex);
switch (elementTypeProfile.profile(mdDeviceArray.getElementType())) {
case BYTE:
case CHAR:
Expand All @@ -136,12 +133,7 @@ Object readArrayElement(long index,
return null;
} else {
long off = offset + index * stride;
long newStride;
if (mdDeviceArray.isColumnMajorFormat()) {
newStride = stride * mdDeviceArray.getElementsInDimension(thisDimension);
} else {
newStride = stride / mdDeviceArray.getElementsInDimension(thisDimension + 1);
}
long newStride = mdDeviceArray.getStrideInDimension(thisDimension + 1);
return new MultiDimDeviceArrayView(mdDeviceArray, thisDimension + 1, off, newStride);
}
}
Expand All @@ -150,14 +142,12 @@ Object readArrayElement(long index,
void writeArrayElement(long index, Object value,
@CachedLibrary(limit = "3") InteropLibrary valueLibrary,
@Shared("elementType") @Cached("createIdentityProfile()") ValueProfile elementTypeProfile) throws UnsupportedTypeException, InvalidArrayIndexException {
// System.out.println("MultiDimDeviceArrayView::writeArrayElement(" + index + ')');
if ((index < 0) || (index >= mdDeviceArray.getElementsInDimension(thisDimension))) {
CompilerDirectives.transferToInterpreter();
throw InvalidArrayIndexException.create(index);
}
if ((thisDimension + 1) == mdDeviceArray.getNumberDimensions()) {
long flatIndex = offset + index * stride;
// System.out.println("W access " + flatIndex);
try {
switch (elementTypeProfile.profile(mdDeviceArray.getElementType())) {
case BYTE:
Expand All @@ -183,10 +173,10 @@ void writeArrayElement(long index, Object value,
}
} catch (UnsupportedMessageException e) {
CompilerDirectives.transferToInterpreter();
throw UnsupportedTypeException.create(new Object[]{value}, "value cannot be coerced to " +
mdDeviceArray.getElementType());
throw UnsupportedTypeException.create(new Object[]{value}, "value cannot be coerced to " + mdDeviceArray.getElementType());
}
} else {
CompilerDirectives.transferToInterpreter();
throw new IllegalStateException("tried to write non-last dimension in MultiDimDeviceArrayView");
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -35,7 +36,7 @@
*/
public class LittleEndianNativeArrayView {

private final Unsafe unsafe;
private static final Unsafe unsafe = UnsafeHelper.getUnsafe();
private final long startAddress;
private final long sizeInBytes;

Expand Down Expand Up @@ -110,7 +111,6 @@ public String toString() {
}

LittleEndianNativeArrayView(long startAddress, long sizeInBytes) {
this.unsafe = UnsafeHelper.getUnsafe();
this.startAddress = startAddress;
this.sizeInBytes = sizeInBytes;
}
Expand Down

0 comments on commit 02193bb

Please sign in to comment.