diff --git a/.scalafmt.conf b/.scalafmt.conf
new file mode 100644
index 00000000..834f2d20
--- /dev/null
+++ b/.scalafmt.conf
@@ -0,0 +1 @@
+version = 2.7.5
\ No newline at end of file
diff --git a/README.md b/README.md
index b9cb6da3..b23b6591 100755
--- a/README.md
+++ b/README.md
@@ -48,7 +48,7 @@ Build instructions:
```
% git clone https://github.com/openucx/sparkucx
% cd sparkucx
-% mvn -DskipTests clean package -Pspark-2.4
+% mvn -DskipTests clean package -Pspark-3.0
```
### Performance
diff --git a/pom.xml b/pom.xml
index b23b54ac..1e2826ad 100755
--- a/pom.xml
+++ b/pom.xml
@@ -149,7 +149,7 @@ See file LICENSE for terms.
org.openucx
jucx
- 1.11.0-rc3
+ 1.14.1
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/UcxNode.java b/src/main/java/org/apache/spark/shuffle/ucx/UcxNode.java
index 625281b9..5f0f73f6 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/UcxNode.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/UcxNode.java
@@ -33,200 +33,201 @@
* Single instance class per spark process, that keeps UcpContext, memory and worker pools.
*/
public class UcxNode implements Closeable {
- // Global
- private static final Logger logger = LoggerFactory.getLogger(UcxNode.class);
- private final boolean isDriver;
- private final UcpContext context;
- private final MemoryPool memoryPool;
- private final UcpWorkerParams workerParams = new UcpWorkerParams();
- private final UcpWorker globalWorker;
- private final UcxShuffleConf conf;
- // Mapping from spark's entity of BlockManagerId to UcxEntity workerAddress.
- private static final ConcurrentHashMap workerAdresses =
- new ConcurrentHashMap<>();
- private final Thread listenerProgressThread;
- private boolean closed = false;
-
- // Driver
- private UcpListener listener;
- // Mapping from UcpEndpoint to ByteBuffer of RPC message, to introduce executor to cluster
- private static final ConcurrentHashMap rpcConnections =
- new ConcurrentHashMap<>();
- private List backwardEndpoints = new ArrayList<>();
-
- // Executor
- private UcpEndpoint globalDriverEndpoint;
- // Keep track of allocated workers to correctly close them.
- private static final Set allocatedWorkers = ConcurrentHashMap.newKeySet();
- private final ThreadLocal threadLocalWorker;
-
- public UcxNode(UcxShuffleConf conf, boolean isDriver) {
- this.conf = conf;
- this.isDriver = isDriver;
- UcpParams params = new UcpParams().requestTagFeature()
- .requestRmaFeature().requestWakeupFeature()
- .setMtWorkersShared(true);
- context = new UcpContext(params);
- memoryPool = new MemoryPool(context, conf);
- globalWorker = context.newWorker(workerParams);
- InetSocketAddress driverAddress = new InetSocketAddress(conf.driverHost(), conf.driverPort());
-
- if (isDriver) {
- startDriver(driverAddress);
- } else {
- startExecutor(driverAddress);
+ // Global
+ private static final Logger logger = LoggerFactory.getLogger(UcxNode.class);
+ private final boolean isDriver;
+ private final UcpContext context;
+ private final MemoryPool memoryPool;
+ private final UcpWorkerParams workerParams = new UcpWorkerParams();
+ private final UcpWorker globalWorker;
+ private final UcxShuffleConf conf;
+ // Mapping from spark's entity of BlockManagerId to UcxEntity workerAddress.
+ private static final ConcurrentHashMap workerAdresses =
+ new ConcurrentHashMap<>();
+ private final Thread listenerProgressThread;
+ private boolean closed = false;
+
+ // Driver
+ private UcpListener listener;
+ // Mapping from UcpEndpoint to ByteBuffer of RPC message, to introduce executor to cluster
+ private static final ConcurrentHashMap rpcConnections =
+ new ConcurrentHashMap<>();
+ private List backwardEndpoints = new ArrayList<>();
+
+ // Executor
+ private UcpEndpoint globalDriverEndpoint;
+ // Keep track of allocated workers to correctly close them.
+ private static final Set allocatedWorkers = ConcurrentHashMap.newKeySet();
+ private final ThreadLocal threadLocalWorker;
+
+ public UcxNode(UcxShuffleConf conf, boolean isDriver) {
+ this.conf = conf;
+ this.isDriver = isDriver;
+ UcpParams params = new UcpParams().requestTagFeature()
+ .requestRmaFeature().requestWakeupFeature()
+ .setMtWorkersShared(true);
+ context = new UcpContext(params);
+ memoryPool = new MemoryPool(context, conf);
+ globalWorker = context.newWorker(workerParams);
+ InetSocketAddress driverAddress = new InetSocketAddress(conf.driverHost(), conf.driverPort());
+
+ if (isDriver) {
+ startDriver(driverAddress);
+ } else {
+ startExecutor(driverAddress);
+ }
+
+ // Global listener thread, that keeps lazy progress for connection establishment
+ listenerProgressThread = new UcxListenerThread(this, isDriver);
+ listenerProgressThread.start();
+
+ if (!isDriver) {
+ memoryPool.preAlocate();
+ }
+
+ threadLocalWorker = ThreadLocal.withInitial(() -> {
+ UcpWorker localWorker = context.newWorker(workerParams);
+ UcxWorkerWrapper result = new UcxWorkerWrapper(localWorker,
+ conf, allocatedWorkers.size());
+ if (result.id() > conf.coresPerProcess()) {
+ logger.warn("Thread: {} - creates new worker {} > numCores",
+ Thread.currentThread().getId(), result.id());
+ }
+ allocatedWorkers.add(result);
+ return result;
+ });
+ }
+
+ private void startDriver(InetSocketAddress driverAddress) {
+ // 1. Start listener on a driver and accept RPC messages from executors with their
+ // worker addresses
+ UcpListenerParams listenerParams = new UcpListenerParams().setSockAddr(driverAddress)
+ .setConnectionHandler(ucpConnectionRequest ->
+ backwardEndpoints.add(globalWorker.newEndpoint(new UcpEndpointParams()
+ .setConnectionRequest(ucpConnectionRequest))));
+ listener = globalWorker.newListener(listenerParams);
+ logger.info("Started UcxNode on {}", driverAddress);
}
- // Global listener thread, that keeps lazy progress for connection establishment
- listenerProgressThread = new UcxListenerThread(this, isDriver);
- listenerProgressThread.start();
+ /**
+ * Allocates ByteBuffer from memoryPool and serializes workerAddress into it,
+ * followed by BlockManagerID
+ *
+ * @return RegisteredMemory that holds metadata buffer.
+ */
+ private RegisteredMemory buildMetadataBuffer() {
+ BlockManagerId blockManagerId = SparkEnv.get().blockManager().blockManagerId();
+ ByteBuffer workerAddresses = globalWorker.getAddress();
+
+ RegisteredMemory metadataMemory = memoryPool.get(conf.metadataRPCBufferSize());
+ ByteBuffer metadataBuffer = metadataMemory.getBuffer();
+ metadataBuffer.putInt(workerAddresses.capacity());
+ metadataBuffer.put(workerAddresses);
+ try {
+ SerializableBlockManagerID.serializeBlockManagerID(blockManagerId, metadataBuffer);
+ } catch (IOException e) {
+ String errorMsg = String.format("Failed to serialize %s: %s", blockManagerId,
+ e.getMessage());
+ throw new UcxException(errorMsg);
+ }
+ metadataBuffer.clear();
+ return metadataMemory;
+ }
- if (!isDriver) {
- memoryPool.preAlocate();
+ private void startExecutor(InetSocketAddress driverAddress) {
+ // 1. Executor: connect to driver using sockaddr
+ // and send its worker address followed by BlockManagerID.
+ globalDriverEndpoint = globalWorker.newEndpoint(
+ new UcpEndpointParams().setSocketAddress(driverAddress).setPeerErrorHandlingMode()
+ );
+
+ RegisteredMemory metadataMemory = buildMetadataBuffer();
+ // TODO: send using stream API once it is available in jucx.
+ globalDriverEndpoint.sendTaggedNonBlocking(metadataMemory.getBuffer(), new UcxCallback() {
+ @Override
+ public void onSuccess(UcpRequest request) {
+ memoryPool.put(metadataMemory);
+ }
+ });
}
- threadLocalWorker = ThreadLocal.withInitial(() -> {
- UcpWorker localWorker = context.newWorker(workerParams);
- UcxWorkerWrapper result = new UcxWorkerWrapper(localWorker,
- conf, allocatedWorkers.size());
- if (result.id() > conf.coresPerProcess()) {
- logger.warn("Thread: {} - creates new worker {} > numCores",
- Thread.currentThread().getId(), result.id());
- }
- allocatedWorkers.add(result);
- return result;
- });
- }
-
- private void startDriver(InetSocketAddress driverAddress) {
- // 1. Start listener on a driver and accept RPC messages from executors with their
- // worker addresses
- UcpListenerParams listenerParams = new UcpListenerParams().setSockAddr(driverAddress)
- .setConnectionHandler(ucpConnectionRequest ->
- backwardEndpoints.add(globalWorker.newEndpoint(new UcpEndpointParams()
- .setConnectionRequest(ucpConnectionRequest))));
- listener = globalWorker.newListener(listenerParams);
- logger.info("Started UcxNode on {}", driverAddress);
- }
-
- /**
- * Allocates ByteBuffer from memoryPool and serializes there workerAddress,
- * followed by BlockManagerID
- * @return RegisteredMemory that holds metadata buffer.
- */
- private RegisteredMemory buildMetadataBuffer() {
- BlockManagerId blockManagerId = SparkEnv.get().blockManager().blockManagerId();
- ByteBuffer workerAddresses = globalWorker.getAddress();
-
- RegisteredMemory metadataMemory = memoryPool.get(conf.metadataRPCBufferSize());
- ByteBuffer metadataBuffer = metadataMemory.getBuffer();
- metadataBuffer.putInt(workerAddresses.capacity());
- metadataBuffer.put(workerAddresses);
- try {
- SerializableBlockManagerID.serializeBlockManagerID(blockManagerId, metadataBuffer);
- } catch (IOException e) {
- String errorMsg = String.format("Failed to serialize %s: %s", blockManagerId,
- e.getMessage());
- throw new UcxException(errorMsg);
+ public UcxShuffleConf getConf() {
+ return conf;
}
- metadataBuffer.clear();
- return metadataMemory;
- }
-
- private void startExecutor(InetSocketAddress driverAddress) {
- // 1. Executor: connect to driver using sockaddr
- // and send it's worker address followed by BlockManagerID.
- globalDriverEndpoint = globalWorker.newEndpoint(
- new UcpEndpointParams().setSocketAddress(driverAddress).setPeerErrorHandlingMode()
- );
-
- RegisteredMemory metadataMemory = buildMetadataBuffer();
- // TODO: send using stream API when it would be available in jucx.
- globalDriverEndpoint.sendTaggedNonBlocking(metadataMemory.getBuffer(), new UcxCallback() {
- @Override
- public void onSuccess(UcpRequest request) {
- memoryPool.put(metadataMemory);
- }
- });
- }
-
- public UcxShuffleConf getConf() {
- return conf;
- }
-
- public UcpWorker getGlobalWorker() {
- return globalWorker;
- }
-
- public MemoryPool getMemoryPool() {
- return memoryPool;
- }
-
- public UcpContext getContext() {
- return context;
- }
-
- /**
- * Get or initialize worker for current thread
- */
- public UcxWorkerWrapper getThreadLocalWorker() {
- return threadLocalWorker.get();
- }
-
- public static ConcurrentMap getWorkerAddresses() {
- return workerAdresses;
- }
-
- public static ConcurrentMap getRpcConnections() {
- return rpcConnections;
- }
-
- private void stopDriver() {
- if (listener != null) {
- listener.close();
- listener = null;
+
+ public UcpWorker getGlobalWorker() {
+ return globalWorker;
}
- rpcConnections.keySet().forEach(UcpEndpoint::close);
- rpcConnections.clear();
- backwardEndpoints.forEach(UcpEndpoint::close);
- backwardEndpoints.clear();
- }
-
- private void stopExecutor() {
- if (globalDriverEndpoint != null) {
- globalDriverEndpoint.close();
- globalDriverEndpoint = null;
+
+ public MemoryPool getMemoryPool() {
+ return memoryPool;
}
- allocatedWorkers.forEach(UcxWorkerWrapper::close);
- allocatedWorkers.clear();
- }
-
- @Override
- public void close() {
- threadLocalWorker.remove();
- synchronized (this) {
- if (!closed) {
- logger.info("Stopping UcxNode");
- listenerProgressThread.interrupt();
- globalWorker.signal();
- try {
- listenerProgressThread.join();
- if (isDriver) {
- stopDriver();
- } else {
- stopExecutor();
- }
- memoryPool.close();
- globalWorker.close();
- context.close();
- closed = true;
- } catch (InterruptedException e) {
- logger.error(e.getMessage());
- Thread.currentThread().interrupt();
- } catch (Exception ex) {
- logger.warn(ex.getLocalizedMessage());
+
+ public UcpContext getContext() {
+ return context;
+ }
+
+ /**
+ * Get or initialize worker for current thread
+ */
+ public UcxWorkerWrapper getThreadLocalWorker() {
+ return threadLocalWorker.get();
+ }
+
+ public static ConcurrentMap getWorkerAddresses() {
+ return workerAdresses;
+ }
+
+ public static ConcurrentMap getRpcConnections() {
+ return rpcConnections;
+ }
+
+ private void stopDriver() {
+ if (listener != null) {
+ listener.close();
+ listener = null;
+ }
+ rpcConnections.keySet().forEach(UcpEndpoint::close);
+ rpcConnections.clear();
+ backwardEndpoints.forEach(UcpEndpoint::close);
+ backwardEndpoints.clear();
+ }
+
+ private void stopExecutor() {
+ if (globalDriverEndpoint != null) {
+ globalDriverEndpoint.close();
+ globalDriverEndpoint = null;
+ }
+ allocatedWorkers.forEach(UcxWorkerWrapper::close);
+ allocatedWorkers.clear();
+ }
+
+ @Override
+ public void close() {
+ threadLocalWorker.remove();
+ synchronized (this) {
+ if (!closed) {
+ logger.info("Stopping UcxNode");
+ listenerProgressThread.interrupt();
+ globalWorker.signal();
+ try {
+ listenerProgressThread.join();
+ if (isDriver) {
+ stopDriver();
+ } else {
+ stopExecutor();
+ }
+ memoryPool.close();
+ globalWorker.close();
+ context.close();
+ closed = true;
+ } catch (InterruptedException e) {
+ logger.error(e.getMessage());
+ Thread.currentThread().interrupt();
+ } catch (Exception ex) {
+ logger.warn(ex.getLocalizedMessage());
+ }
+ }
}
- }
}
- }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/UnsafeUtils.java b/src/main/java/org/apache/spark/shuffle/ucx/UnsafeUtils.java
index 9e9e0843..adbfc8c8 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/UnsafeUtils.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/UnsafeUtils.java
@@ -20,58 +20,59 @@
* Java's native mmap functionality, that allows to mmap files > 2GB.
*/
public class UnsafeUtils {
- private static final Method mmap;
- private static final Method unmmap;
- private static final Logger logger = LoggerFactory.getLogger(UnsafeUtils.class);
+ private static final Method mmap;
+ private static final Method unmmap;
+ private static final Logger logger = LoggerFactory.getLogger(UnsafeUtils.class);
- private static final Constructor> directBufferConstructor;
+ private static final Constructor> directBufferConstructor;
- public static final int LONG_SIZE = 8;
- public static final int INT_SIZE = 4;
+ public static final int LONG_SIZE = 8;
+ public static final int INT_SIZE = 4;
- static {
- try {
- mmap = FileChannelImpl.class.getDeclaredMethod("map0", int.class, long.class, long.class);
- mmap.setAccessible(true);
- unmmap = FileChannelImpl.class.getDeclaredMethod("unmap0", long.class, long.class);
- unmmap.setAccessible(true);
- Class> classDirectByteBuffer = Class.forName("java.nio.DirectByteBuffer");
- directBufferConstructor = classDirectByteBuffer.getDeclaredConstructor(long.class, int.class);
- directBufferConstructor.setAccessible(true);
- } catch (Exception e) {
- throw new RuntimeException(e);
+ static {
+ try {
+ mmap = FileChannelImpl.class.getDeclaredMethod("map0", int.class, long.class, long.class);
+ mmap.setAccessible(true);
+ unmmap = FileChannelImpl.class.getDeclaredMethod("unmap0", long.class, long.class);
+ unmmap.setAccessible(true);
+ Class> classDirectByteBuffer = Class.forName("java.nio.DirectByteBuffer");
+ directBufferConstructor = classDirectByteBuffer.getDeclaredConstructor(long.class, int.class);
+ directBufferConstructor.setAccessible(true);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
}
- }
- private UnsafeUtils() {}
+ private UnsafeUtils() {
+ }
- public static long mmap(FileChannel fileChannel, long offset, long length) {
- long result;
- try {
- result = (long)mmap.invoke(fileChannel, 1, offset, length);
- } catch (Exception e) {
- logger.error("MMap({}, {}) failed: {}", offset, length, e.getMessage());
- throw new UcxException(e.getMessage());
+ public static long mmap(FileChannel fileChannel, long offset, long length) {
+ long result;
+ try {
+ result = (long) mmap.invoke(fileChannel, 1, offset, length);
+ } catch (Exception e) {
+ logger.error("MMap({}, {}) failed: {}", offset, length, e.getMessage());
+ throw new UcxException(e.getMessage());
+ }
+ return result;
}
- return result;
- }
- public static void munmap(long address, long length) {
- try {
- unmmap.invoke(null, address, length);
- } catch (IllegalAccessException | InvocationTargetException e) {
- logger.error(e.getMessage());
+ public static void munmap(long address, long length) {
+ try {
+ unmmap.invoke(null, address, length);
+ } catch (IllegalAccessException | InvocationTargetException e) {
+ logger.error(e.getMessage());
+ }
}
- }
- public static ByteBuffer getByteBuffer(long address, int length) throws IOException {
- try {
- return (ByteBuffer)directBufferConstructor.newInstance(address, length);
- } catch (InvocationTargetException ex) {
- throw new IOException("java.nio.DirectByteBuffer: " +
- "InvocationTargetException: " + ex.getTargetException());
- } catch (Exception e) {
- throw new IOException("java.nio.DirectByteBuffer exception: " + e.getMessage());
+ public static ByteBuffer getByteBuffer(long address, int length) throws IOException {
+ try {
+ return (ByteBuffer) directBufferConstructor.newInstance(address, length);
+ } catch (InvocationTargetException ex) {
+ throw new IOException("java.nio.DirectByteBuffer: " +
+ "InvocationTargetException: " + ex.getTargetException());
+ } catch (Exception e) {
+ throw new IOException("java.nio.DirectByteBuffer exception: " + e.getMessage());
+ }
}
- }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/memory/MemoryPool.java b/src/main/java/org/apache/spark/shuffle/ucx/memory/MemoryPool.java
index 200e9ee2..2d5979ec 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/memory/MemoryPool.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/memory/MemoryPool.java
@@ -25,155 +25,155 @@
* and registration during shuffle phase.
*/
public class MemoryPool implements Closeable {
- private static final Logger logger = LoggerFactory.getLogger(MemoryPool.class);
-
- @Override
- public void close() {
- for (AllocatorStack stack: allocStackMap.values()) {
- stack.close();
- logger.info("Stack of size {}. " +
- "Total requests: {}, total allocations: {}, preAllocations: {}",
- stack.length, stack.totalRequests.get(), stack.totalAlloc.get(), stack.preAllocs.get());
- }
- allocStackMap.clear();
- }
-
- private class AllocatorStack implements Closeable {
- private final AtomicInteger totalRequests = new AtomicInteger(0);
- private final AtomicInteger totalAlloc = new AtomicInteger(0);
- private final AtomicInteger preAllocs = new AtomicInteger(0);
- private final ConcurrentLinkedDeque stack = new ConcurrentLinkedDeque<>();
- private final int length;
-
- private AllocatorStack(int length) {
- this.length = length;
- }
+ private static final Logger logger = LoggerFactory.getLogger(MemoryPool.class);
- private RegisteredMemory get() {
- RegisteredMemory result = stack.pollFirst();
- if (result == null) {
- if (length < conf.minRegistrationSize()) {
- int numBuffers = conf.minRegistrationSize() / length;
- logger.debug("Allocating {} buffers of size {}", numBuffers, length);
- preallocate(numBuffers);
- result = stack.pollFirst();
- if (result == null) {
- return get();
- } else {
- result.getRefCount().incrementAndGet();
- }
- } else {
- UcpMemMapParams memMapParams = new UcpMemMapParams().setLength(length).allocate();
- UcpMemory memory = context.memoryMap(memMapParams);
- ByteBuffer buffer;
- try {
- buffer = UcxUtils.getByteBufferView(memory.getAddress(), (int)memory.getLength());
- } catch (Exception e) {
- throw new UcxException(e.getMessage());
- }
- result = new RegisteredMemory(new AtomicInteger(1), memory, buffer);
- totalAlloc.incrementAndGet();
+ @Override
+ public void close() {
+ for (AllocatorStack stack : allocStackMap.values()) {
+ stack.close();
+ logger.info("Stack of size {}. " +
+ "Total requests: {}, total allocations: {}, preAllocations: {}",
+ stack.length, stack.totalRequests.get(), stack.totalAlloc.get(), stack.preAllocs.get());
}
- } else {
- result.getRefCount().incrementAndGet();
- }
- totalRequests.incrementAndGet();
- return result;
+ allocStackMap.clear();
}
- private void put(RegisteredMemory registeredMemory) {
- registeredMemory.getRefCount().decrementAndGet();
- stack.addLast(registeredMemory);
+ private class AllocatorStack implements Closeable {
+ private final AtomicInteger totalRequests = new AtomicInteger(0);
+ private final AtomicInteger totalAlloc = new AtomicInteger(0);
+ private final AtomicInteger preAllocs = new AtomicInteger(0);
+ private final ConcurrentLinkedDeque stack = new ConcurrentLinkedDeque<>();
+ private final int length;
+
+ private AllocatorStack(int length) {
+ this.length = length;
+ }
+
+ private RegisteredMemory get() {
+ RegisteredMemory result = stack.pollFirst();
+ if (result == null) {
+ if (length < conf.minRegistrationSize()) {
+ int numBuffers = conf.minRegistrationSize() / length;
+ logger.debug("Allocating {} buffers of size {}", numBuffers, length);
+ preallocate(numBuffers);
+ result = stack.pollFirst();
+ if (result == null) {
+ return get();
+ } else {
+ result.getRefCount().incrementAndGet();
+ }
+ } else {
+ UcpMemMapParams memMapParams = new UcpMemMapParams().setLength(length).allocate();
+ UcpMemory memory = context.memoryMap(memMapParams);
+ ByteBuffer buffer;
+ try {
+ buffer = UcxUtils.getByteBufferView(memory.getAddress(), (int) memory.getLength());
+ } catch (Exception e) {
+ throw new UcxException(e.getMessage());
+ }
+ result = new RegisteredMemory(new AtomicInteger(1), memory, buffer);
+ totalAlloc.incrementAndGet();
+ }
+ } else {
+ result.getRefCount().incrementAndGet();
+ }
+ totalRequests.incrementAndGet();
+ return result;
+ }
+
+ private void put(RegisteredMemory registeredMemory) {
+ registeredMemory.getRefCount().decrementAndGet();
+ stack.addLast(registeredMemory);
+ }
+
+ private void preallocate(int numBuffers) {
+ // Platform.allocateDirectBuffer supports only 2GB of buffer.
+ // Decrease number of buffers if total size of preAllocation > 2GB.
+ if ((long) length * (long) numBuffers > Integer.MAX_VALUE) {
+ numBuffers = Integer.MAX_VALUE / length;
+ }
+
+ UcpMemMapParams memMapParams = new UcpMemMapParams().allocate().setLength(numBuffers * (long) length);
+ UcpMemory memory = context.memoryMap(memMapParams);
+ ByteBuffer buffer;
+ try {
+ buffer = UnsafeUtils.getByteBuffer(memory.getAddress(), numBuffers * length);
+ } catch (Exception ex) {
+ throw new UcxException(ex.getMessage());
+ }
+
+ AtomicInteger refCount = new AtomicInteger(numBuffers);
+ for (int i = 0; i < numBuffers; i++) {
+ buffer.position(i * length).limit(i * length + length);
+ final ByteBuffer slice = buffer.slice();
+ RegisteredMemory registeredMemory = new RegisteredMemory(refCount, memory, slice);
+ put(registeredMemory);
+ }
+ preAllocs.incrementAndGet();
+ totalAlloc.incrementAndGet();
+ }
+
+ @Override
+ public void close() {
+ while (!stack.isEmpty()) {
+ RegisteredMemory memory = stack.pollFirst();
+ if (memory != null) {
+ memory.deregisterNativeMemory();
+ }
+ }
+ }
}
- private void preallocate(int numBuffers) {
- // Platform.allocateDirectBuffer supports only 2GB of buffer.
- // Decrease number of buffers if total size of preAllocation > 2GB.
- if ((long)length * (long)numBuffers > Integer.MAX_VALUE) {
- numBuffers = Integer.MAX_VALUE / length;
- }
-
- UcpMemMapParams memMapParams = new UcpMemMapParams().allocate().setLength(numBuffers * (long)length);
- UcpMemory memory = context.memoryMap(memMapParams);
- ByteBuffer buffer;
- try {
- buffer = UnsafeUtils.getByteBuffer(memory.getAddress(), numBuffers * length);
- } catch (Exception ex) {
- throw new UcxException(ex.getMessage());
- }
-
- AtomicInteger refCount = new AtomicInteger(numBuffers);
- for (int i = 0; i < numBuffers; i++) {
- buffer.position(i * length).limit(i * length + length);
- final ByteBuffer slice = buffer.slice();
- RegisteredMemory registeredMemory = new RegisteredMemory(refCount, memory, slice);
- put(registeredMemory);
- }
- preAllocs.incrementAndGet();
- totalAlloc.incrementAndGet();
+ private final ConcurrentHashMap allocStackMap =
+ new ConcurrentHashMap<>();
+ private final UcpContext context;
+ private final UcxShuffleConf conf;
+
+ public MemoryPool(UcpContext context, UcxShuffleConf conf) {
+ this.context = context;
+ this.conf = conf;
}
- @Override
- public void close() {
- while (!stack.isEmpty()) {
- RegisteredMemory memory = stack.pollFirst();
- if (memory != null) {
- memory.deregisterNativeMemory();
+ private long roundUpToTheNextPowerOf2(long length) {
+ // Round up length to the nearest power of two, or the minimum block size
+ if (length < conf.minBufferSize()) {
+ length = conf.minBufferSize();
+ } else {
+ length--;
+ length |= length >> 1;
+ length |= length >> 2;
+ length |= length >> 4;
+ length |= length >> 8;
+ length |= length >> 16;
+ length++;
}
- }
+ return length;
+ }
+
+ public RegisteredMemory get(int size) {
+ long roundedSize = roundUpToTheNextPowerOf2(size);
+ assert roundedSize < Integer.MAX_VALUE && roundedSize > 0;
+ AllocatorStack stack =
+ allocStackMap.computeIfAbsent((int) roundedSize, AllocatorStack::new);
+ RegisteredMemory result = stack.get();
+ result.getBuffer().position(0).limit(size);
+ return result;
}
- }
-
- private final ConcurrentHashMap allocStackMap =
- new ConcurrentHashMap<>();
- private final UcpContext context;
- private final UcxShuffleConf conf;
-
- public MemoryPool(UcpContext context, UcxShuffleConf conf) {
- this.context = context;
- this.conf = conf;
- }
-
- private long roundUpToTheNextPowerOf2(long length) {
- // Round up length to the nearest power of two, or the minimum block size
- if (length < conf.minBufferSize()) {
- length = conf.minBufferSize();
- } else {
- length--;
- length |= length >> 1;
- length |= length >> 2;
- length |= length >> 4;
- length |= length >> 8;
- length |= length >> 16;
- length++;
+
+ public void put(RegisteredMemory memory) {
+ AllocatorStack allocatorStack = allocStackMap.get(memory.getBuffer().capacity());
+ if (allocatorStack != null) {
+ allocatorStack.put(memory);
+ }
}
- return length;
- }
-
- public RegisteredMemory get(int size) {
- long roundedSize = roundUpToTheNextPowerOf2(size);
- assert roundedSize < Integer.MAX_VALUE && roundedSize > 0;
- AllocatorStack stack =
- allocStackMap.computeIfAbsent((int)roundedSize, AllocatorStack::new);
- RegisteredMemory result = stack.get();
- result.getBuffer().position(0).limit(size);
- return result;
- }
-
- public void put(RegisteredMemory memory) {
- AllocatorStack allocatorStack = allocStackMap.get(memory.getBuffer().capacity());
- if (allocatorStack != null) {
- allocatorStack.put(memory);
+
+ public void preAlocate() {
+ conf.preallocateBuffersMap().forEach((size, numBuffers) -> {
+ logger.debug("Pre allocating {} buffers of size {}", numBuffers, size);
+ AllocatorStack stack = new AllocatorStack(size);
+ allocStackMap.put(size, stack);
+ stack.preallocate(numBuffers);
+ });
}
- }
-
- public void preAlocate() {
- conf.preallocateBuffersMap().forEach((size, numBuffers) -> {
- logger.debug("Pre allocating {} buffers of size {}", numBuffers, size);
- AllocatorStack stack = new AllocatorStack(size);
- allocStackMap.put(size, stack);
- stack.preallocate(numBuffers);
- });
- }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/memory/RegisteredMemory.java b/src/main/java/org/apache/spark/shuffle/ucx/memory/RegisteredMemory.java
index 47be67ef..9f8fbd5c 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/memory/RegisteredMemory.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/memory/RegisteredMemory.java
@@ -12,32 +12,32 @@
* Keeps track on reference count to memory region.
*/
public class RegisteredMemory {
- private static final Logger logger = LoggerFactory.getLogger(RegisteredMemory.class);
+ private static final Logger logger = LoggerFactory.getLogger(RegisteredMemory.class);
- private final AtomicInteger refcount;
- private final UcpMemory memory;
- private final ByteBuffer buffer;
+ private final AtomicInteger refcount;
+ private final UcpMemory memory;
+ private final ByteBuffer buffer;
- RegisteredMemory(AtomicInteger refcount, UcpMemory memory, ByteBuffer buffer) {
- this.refcount = refcount;
- this.memory = memory;
- this.buffer = buffer;
- }
-
- public ByteBuffer getBuffer() {
- return buffer;
- }
+ RegisteredMemory(AtomicInteger refcount, UcpMemory memory, ByteBuffer buffer) {
+ this.refcount = refcount;
+ this.memory = memory;
+ this.buffer = buffer;
+ }
- AtomicInteger getRefCount() {
- return refcount;
- }
+ public ByteBuffer getBuffer() {
+ return buffer;
+ }
- void deregisterNativeMemory() {
- if (refcount.get() != 0) {
- logger.warn("De-registering memory of size {} that has active references.", buffer.capacity());
+ AtomicInteger getRefCount() {
+ return refcount;
}
- if (memory != null && memory.getNativeId() != null) {
- memory.deregister();
+
+ void deregisterNativeMemory() {
+ if (refcount.get() != 0) {
+ logger.warn("De-registering memory of size {} that has active references.", buffer.capacity());
+ }
+ if (memory != null && memory.getNativeId() != null) {
+ memory.deregister();
+ }
}
- }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/OnBlocksFetchCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/OnBlocksFetchCallback.java
index 57f2fe83..5938e81e 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/OnBlocksFetchCallback.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/OnBlocksFetchCallback.java
@@ -20,39 +20,39 @@
* Notifies Spark's shuffleFetchIterator on block fetch completion.
*/
public class OnBlocksFetchCallback extends ReducerCallback {
- protected RegisteredMemory blocksMemory;
- protected int[] sizes;
+ protected RegisteredMemory blocksMemory;
+ protected int[] sizes;
- public OnBlocksFetchCallback(ReducerCallback callback, RegisteredMemory blocksMemory, int[] sizes) {
- super(callback);
- this.blocksMemory = blocksMemory;
- this.sizes = sizes;
- }
+ public OnBlocksFetchCallback(ReducerCallback callback, RegisteredMemory blocksMemory, int[] sizes) {
+ super(callback);
+ this.blocksMemory = blocksMemory;
+ this.sizes = sizes;
+ }
- @Override
- public void onSuccess(UcpRequest request) {
- int position = 0;
- AtomicInteger refCount = new AtomicInteger(blockIds.length);
- for (int i = 0; i < blockIds.length; i++) {
- BlockId block = blockIds[i];
- // Blocks are fetched to contiguous buffer.
- // |----block1---||---block2---||---block3---|
- // Slice each block to avoid buffer copy.
- blocksMemory.getBuffer().position(position).limit(position + sizes[i]);
- ByteBuffer blockBuffer = blocksMemory.getBuffer().slice();
- position += sizes[i];
- // Pass block to Spark's ShuffleFetchIterator.
- listener.onBlockFetchSuccess(block.name(), new NioManagedBuffer(blockBuffer) {
- @Override
- public ManagedBuffer release() {
- if (refCount.decrementAndGet() == 0) {
- mempool.put(blocksMemory);
- }
- return this;
+ @Override
+ public void onSuccess(UcpRequest request) {
+ int position = 0;
+ AtomicInteger refCount = new AtomicInteger(blockIds.length);
+ for (int i = 0; i < blockIds.length; i++) {
+ BlockId block = blockIds[i];
+ // Blocks are fetched to contiguous buffer.
+ // |----block1---||---block2---||---block3---|
+ // Slice each block to avoid buffer copy.
+ blocksMemory.getBuffer().position(position).limit(position + sizes[i]);
+ ByteBuffer blockBuffer = blocksMemory.getBuffer().slice();
+ position += sizes[i];
+ // Pass block to Spark's ShuffleFetchIterator.
+ listener.onBlockFetchSuccess(block.name(), new NioManagedBuffer(blockBuffer) {
+ @Override
+ public ManagedBuffer release() {
+ if (refCount.decrementAndGet() == 0) {
+ mempool.put(blocksMemory);
+ }
+ return this;
+ }
+ });
}
- });
+ logger.info("Endpoint {} fetched {} blocks of total size {} in {}ms", endpoint.getNativeId(), blockIds.length,
+ Utils.bytesToString(position), System.currentTimeMillis() - startTime);
}
- logger.info("Endpoint {} fetched {} blocks of total size {} in {}ms", endpoint.getNativeId(), blockIds.length,
- Utils.bytesToString(position), System.currentTimeMillis() - startTime);
- }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/ReducerCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/ReducerCallback.java
index 612ce9c3..1d722539 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/ReducerCallback.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/ReducerCallback.java
@@ -18,25 +18,25 @@
* Common data needed for offset fetch callback and subsequent block fetch callback.
*/
public abstract class ReducerCallback extends UcxCallback {
- protected MemoryPool mempool;
- protected BlockId[] blockIds;
- protected UcpEndpoint endpoint;
- protected BlockFetchingListener listener;
- protected static final Logger logger = LoggerFactory.getLogger(ReducerCallback.class);
- protected long startTime = System.currentTimeMillis();
+ protected MemoryPool mempool;
+ protected BlockId[] blockIds;
+ protected UcpEndpoint endpoint;
+ protected BlockFetchingListener listener;
+ protected static final Logger logger = LoggerFactory.getLogger(ReducerCallback.class);
+ protected long startTime = System.currentTimeMillis();
- public ReducerCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener) {
- this.mempool = ((CommonUcxShuffleManager)SparkEnv.get().shuffleManager()).ucxNode().getMemoryPool();
- this.blockIds = blockIds;
- this.endpoint = endpoint;
- this.listener = listener;
- }
+ public ReducerCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener) {
+ this.mempool = ((CommonUcxShuffleManager) SparkEnv.get().shuffleManager()).ucxNode().getMemoryPool();
+ this.blockIds = blockIds;
+ this.endpoint = endpoint;
+ this.listener = listener;
+ }
- public ReducerCallback(ReducerCallback callback) {
- this.blockIds = callback.blockIds;
- this.endpoint = callback.endpoint;
- this.listener = callback.listener;
- this.mempool = callback.mempool;
- this.startTime = callback.startTime;
- }
+ public ReducerCallback(ReducerCallback callback) {
+ this.blockIds = callback.blockIds;
+ this.endpoint = callback.endpoint;
+ this.listener = callback.listener;
+ this.mempool = callback.mempool;
+ this.startTime = callback.startTime;
+ }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/OnOffsetsFetchCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/OnOffsetsFetchCallback.java
index 5f95caf6..93381610 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/OnOffsetsFetchCallback.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/OnOffsetsFetchCallback.java
@@ -22,49 +22,49 @@
* Callback, called when got all offsets for blocks
*/
public class OnOffsetsFetchCallback extends ReducerCallback {
- private final RegisteredMemory offsetMemory;
- private final long[] dataAddresses;
- private Map dataRkeysCache;
+ private final RegisteredMemory offsetMemory;
+ private final long[] dataAddresses;
+ private Map dataRkeysCache;
- public OnOffsetsFetchCallback(ShuffleBlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener,
- RegisteredMemory offsetMemory, long[] dataAddresses,
- Map dataRkeysCache) {
- super(blockIds, endpoint, listener);
- this.offsetMemory = offsetMemory;
- this.dataAddresses = dataAddresses;
- this.dataRkeysCache = dataRkeysCache;
- }
-
- @Override
- public void onSuccess(UcpRequest request) {
- ByteBuffer resultOffset = offsetMemory.getBuffer();
- long totalSize = 0;
- int[] sizes = new int[blockIds.length];
- int offsetSize = UnsafeUtils.LONG_SIZE;
- for (int i = 0; i < blockIds.length; i++) {
- // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd |
- long blockOffset = resultOffset.getLong(i * 2 * offsetSize);
- long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset;
- assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
- sizes[i] = (int) blockLength;
- totalSize += blockLength;
- dataAddresses[i] += blockOffset;
+ public OnOffsetsFetchCallback(ShuffleBlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener,
+ RegisteredMemory offsetMemory, long[] dataAddresses,
+ Map dataRkeysCache) {
+ super(blockIds, endpoint, listener);
+ this.offsetMemory = offsetMemory;
+ this.dataAddresses = dataAddresses;
+ this.dataRkeysCache = dataRkeysCache;
}
- assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
- mempool.put(offsetMemory);
- RegisteredMemory blocksMemory = mempool.get((int) totalSize);
+ @Override
+ public void onSuccess(UcpRequest request) {
+ ByteBuffer resultOffset = offsetMemory.getBuffer();
+ long totalSize = 0;
+ int[] sizes = new int[blockIds.length];
+ int offsetSize = UnsafeUtils.LONG_SIZE;
+ for (int i = 0; i < blockIds.length; i++) {
+ // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd |
+ long blockOffset = resultOffset.getLong(i * 2 * offsetSize);
+ long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset;
+ assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
+ sizes[i] = (int) blockLength;
+ totalSize += blockLength;
+ dataAddresses[i] += blockOffset;
+ }
- long offset = 0;
- // Submits N fetch blocks requests
- for (int i = 0; i < blockIds.length; i++) {
- endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(((ShuffleBlockId)blockIds[i]).mapId()),
- UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]);
- offset += sizes[i];
- }
+ assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
+ mempool.put(offsetMemory);
+ RegisteredMemory blocksMemory = mempool.get((int) totalSize);
- // Process blocks when all fetched.
- // Flush guarantees that callback would invoke when all fetch requests will completed.
- endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes));
- }
+ long offset = 0;
+ // Submit N block-fetch requests.
+ for (int i = 0; i < blockIds.length; i++) {
+ endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(((ShuffleBlockId) blockIds[i]).mapId()),
+ UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]);
+ offset += sizes[i];
+ }
+
+ // Process the blocks once all of them have been fetched.
+ // Flush guarantees that the callback is invoked only after all fetch requests have completed.
+ endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes));
+ }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/UcxShuffleClient.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/UcxShuffleClient.java
index 0595bfc8..4d577b75 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/UcxShuffleClient.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/UcxShuffleClient.java
@@ -28,84 +28,83 @@
import java.util.HashMap;
public class UcxShuffleClient extends ShuffleClient {
- private final MemoryPool mempool;
- private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class);
- private final UcxShuffleManager ucxShuffleManager;
- private final TempShuffleReadMetrics shuffleReadMetrics;
- private final UcxWorkerWrapper workerWrapper;
- final HashMap offsetRkeysCache = new HashMap<>();
- final HashMap dataRkeysCache = new HashMap<>();
-
- public UcxShuffleClient(TempShuffleReadMetrics shuffleReadMetrics,
- UcxWorkerWrapper workerWrapper) {
- this.ucxShuffleManager = (UcxShuffleManager) SparkEnv.get().shuffleManager();
- this.mempool = ucxShuffleManager.ucxNode().getMemoryPool();
- this.shuffleReadMetrics = shuffleReadMetrics;
- this.workerWrapper = workerWrapper;
- }
-
- /**
- * Submits n non blocking fetch offsets to get needed offsets for n blocks.
- */
- private void submitFetchOffsets(UcpEndpoint endpoint, ShuffleBlockId[] blockIds,
- long[] dataAddresses, RegisteredMemory offsetMemory) {
- DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(blockIds[0].shuffleId());
- for (int i = 0; i < blockIds.length; i++) {
- ShuffleBlockId blockId = blockIds[i];
-
- long offsetAddress = driverMetadata.offsetAddress(blockId.mapId());
- dataAddresses[i] = driverMetadata.dataAddress(blockId.mapId());
-
- offsetRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
- endpoint.unpackRemoteKey(driverMetadata.offsetRkey(blockId.mapId())));
-
- dataRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
- endpoint.unpackRemoteKey(driverMetadata.dataRkey(blockId.mapId())));
-
- endpoint.getNonBlockingImplicit(
- offsetAddress + blockId.reduceId() * UnsafeUtils.LONG_SIZE,
- offsetRkeysCache.get(blockId.mapId()),
- UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UnsafeUtils.LONG_SIZE),
- 2L * UnsafeUtils.LONG_SIZE);
+ private final MemoryPool mempool;
+ private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class);
+ private final UcxShuffleManager ucxShuffleManager;
+ private final TempShuffleReadMetrics shuffleReadMetrics;
+ private final UcxWorkerWrapper workerWrapper;
+ final HashMap offsetRkeysCache = new HashMap<>();
+ final HashMap dataRkeysCache = new HashMap<>();
+
+ public UcxShuffleClient(TempShuffleReadMetrics shuffleReadMetrics,
+ UcxWorkerWrapper workerWrapper) {
+ this.ucxShuffleManager = (UcxShuffleManager) SparkEnv.get().shuffleManager();
+ this.mempool = ucxShuffleManager.ucxNode().getMemoryPool();
+ this.shuffleReadMetrics = shuffleReadMetrics;
+ this.workerWrapper = workerWrapper;
+ }
+
+ /**
+ * Submits n non-blocking offset fetches to get the offsets needed for n blocks.
+ */
+ private void submitFetchOffsets(UcpEndpoint endpoint, ShuffleBlockId[] blockIds,
+ long[] dataAddresses, RegisteredMemory offsetMemory) {
+ DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(blockIds[0].shuffleId());
+ for (int i = 0; i < blockIds.length; i++) {
+ ShuffleBlockId blockId = blockIds[i];
+
+ long offsetAddress = driverMetadata.offsetAddress(blockId.mapId());
+ dataAddresses[i] = driverMetadata.dataAddress(blockId.mapId());
+
+ offsetRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
+ endpoint.unpackRemoteKey(driverMetadata.offsetRkey(blockId.mapId())));
+
+ dataRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
+ endpoint.unpackRemoteKey(driverMetadata.dataRkey(blockId.mapId())));
+
+ endpoint.getNonBlockingImplicit(
+ offsetAddress + blockId.reduceId() * UnsafeUtils.LONG_SIZE,
+ offsetRkeysCache.get(blockId.mapId()),
+ UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UnsafeUtils.LONG_SIZE),
+ 2L * UnsafeUtils.LONG_SIZE);
+ }
+ }
+
+ /**
+ * Reducer entry point. Fetches remote blocks, using 2 ucp_get calls.
+ * This method is inside ShuffleFetchIterator's for loop over hosts.
+ * First fetches block offset from index file, and then fetches block itself.
+ */
+ @Override
+ public void fetchBlocks(String host, int port, String execId,
+ String[] blockIds, BlockFetchingListener listener) {
+ long startTime = System.currentTimeMillis();
+
+ BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty());
+ UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId);
+
+ long[] dataAddresses = new long[blockIds.length];
+
+ // Fetch 2 long offsets per block (the current block's and the next block's) to calculate the exact block size.
+ RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length);
+
+ ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds)
+ .map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new);
+
+ // Submits N implicit get requests without callback
+ submitFetchOffsets(endpoint, shuffleBlockIds, dataAddresses, offsetMemory);
+
+ // Flush guarantees that all of those requests have completed by the time the callback is called.
+ endpoint.flushNonBlocking(
+ new OnOffsetsFetchCallback(shuffleBlockIds, endpoint, listener, offsetMemory,
+ dataAddresses, dataRkeysCache));
+ shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime);
+ }
+
+ @Override
+ public void close() {
+ offsetRkeysCache.values().forEach(UcpRemoteKey::close);
+ dataRkeysCache.values().forEach(UcpRemoteKey::close);
+ logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime());
}
- }
-
- /**
- * Reducer entry point. Fetches remote blocks, using 2 ucp_get calls.
- * This method is inside ShuffleFetchIterator's for loop over hosts.
- * First fetches block offset from index file, and then fetches block itself.
- */
- @Override
- public void fetchBlocks(String host, int port, String execId,
- String[] blockIds, BlockFetchingListener listener) {
- long startTime = System.currentTimeMillis();
-
- BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty());
- UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId);
-
- long[] dataAddresses = new long[blockIds.length];
-
- // Need to fetch 2 long offsets current block + next block to calculate exact block size.
- RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length);
-
- ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds)
- .map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new);
-
- // Submits N implicit get requests without callback
- submitFetchOffsets(endpoint, shuffleBlockIds, dataAddresses, offsetMemory);
-
- // flush guarantees that all that requests completes when callback is called.
- // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush.
- workerWrapper.worker().flushNonBlocking(
- new OnOffsetsFetchCallback(shuffleBlockIds, endpoint, listener, offsetMemory,
- dataAddresses, dataRkeysCache));
- shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime);
- }
-
- @Override
- public void close() {
- offsetRkeysCache.values().forEach(UcpRemoteKey::close);
- dataRkeysCache.values().forEach(UcpRemoteKey::close);
- logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime());
- }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/OnOffsetsFetchCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/OnOffsetsFetchCallback.java
index d04e541c..86810d22 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/OnOffsetsFetchCallback.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/OnOffsetsFetchCallback.java
@@ -22,49 +22,49 @@
* Callback, called when got all offsets for blocks
*/
public class OnOffsetsFetchCallback extends ReducerCallback {
- private final RegisteredMemory offsetMemory;
- private final long[] dataAddresses;
- private Map dataRkeysCache;
+ private final RegisteredMemory offsetMemory;
+ private final long[] dataAddresses;
+ private Map dataRkeysCache;
- public OnOffsetsFetchCallback(ShuffleBlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener,
- RegisteredMemory offsetMemory, long[] dataAddresses,
- Map dataRkeysCache) {
- super(blockIds, endpoint, listener);
- this.offsetMemory = offsetMemory;
- this.dataAddresses = dataAddresses;
- this.dataRkeysCache = dataRkeysCache;
- }
-
- @Override
- public void onSuccess(UcpRequest request) {
- ByteBuffer resultOffset = offsetMemory.getBuffer();
- long totalSize = 0;
- int[] sizes = new int[blockIds.length];
- int offsetSize = UnsafeUtils.LONG_SIZE;
- for (int i = 0; i < blockIds.length; i++) {
- // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd |
- long blockOffset = resultOffset.getLong(i * 2 * offsetSize);
- long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset;
- assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
- sizes[i] = (int) blockLength;
- totalSize += blockLength;
- dataAddresses[i] += blockOffset;
+ public OnOffsetsFetchCallback(ShuffleBlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener,
+ RegisteredMemory offsetMemory, long[] dataAddresses,
+ Map dataRkeysCache) {
+ super(blockIds, endpoint, listener);
+ this.offsetMemory = offsetMemory;
+ this.dataAddresses = dataAddresses;
+ this.dataRkeysCache = dataRkeysCache;
}
- assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
- mempool.put(offsetMemory);
- RegisteredMemory blocksMemory = mempool.get((int) totalSize);
+ @Override
+ public void onSuccess(UcpRequest request) {
+ ByteBuffer resultOffset = offsetMemory.getBuffer();
+ long totalSize = 0;
+ int[] sizes = new int[blockIds.length];
+ int offsetSize = UnsafeUtils.LONG_SIZE;
+ for (int i = 0; i < blockIds.length; i++) {
+ // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd |
+ long blockOffset = resultOffset.getLong(i * 2 * offsetSize);
+ long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset;
+ assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
+ sizes[i] = (int) blockLength;
+ totalSize += blockLength;
+ dataAddresses[i] += blockOffset;
+ }
- long offset = 0;
- // Submits N fetch blocks requests
- for (int i = 0; i < blockIds.length; i++) {
- endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(((ShuffleBlockId)blockIds[i]).mapId()),
- UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]);
- offset += sizes[i];
- }
+ assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
+ mempool.put(offsetMemory);
+ RegisteredMemory blocksMemory = mempool.get((int) totalSize);
- // Process blocks when all fetched.
- // Flush guarantees that callback would invoke when all fetch requests will completed.
- endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes));
- }
+ long offset = 0;
+ // Submit N block-fetch requests.
+ for (int i = 0; i < blockIds.length; i++) {
+ endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(((ShuffleBlockId) blockIds[i]).mapId()),
+ UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]);
+ offset += sizes[i];
+ }
+
+ // Process the blocks once all of them have been fetched.
+ // Flush guarantees that the callback is invoked only after all fetch requests have completed.
+ endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes));
+ }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/UcxShuffleClient.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/UcxShuffleClient.java
index e2c499c5..4036b639 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/UcxShuffleClient.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/UcxShuffleClient.java
@@ -27,85 +27,84 @@
import java.util.HashMap;
public class UcxShuffleClient extends ShuffleClient {
- private final MemoryPool mempool;
- private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class);
- private final UcxShuffleManager ucxShuffleManager;
- private final TempShuffleReadMetrics shuffleReadMetrics;
- private final UcxWorkerWrapper workerWrapper;
- final HashMap offsetRkeysCache = new HashMap<>();
- final HashMap dataRkeysCache = new HashMap<>();
-
- public UcxShuffleClient(TempShuffleReadMetrics shuffleReadMetrics,
- UcxWorkerWrapper workerWrapper) {
- this.ucxShuffleManager = (UcxShuffleManager) SparkEnv.get().shuffleManager();
- this.mempool = ucxShuffleManager.ucxNode().getMemoryPool();
- this.shuffleReadMetrics = shuffleReadMetrics;
- this.workerWrapper = workerWrapper;
- }
-
- /**
- * Submits n non blocking fetch offsets to get needed offsets for n blocks.
- */
- private void submitFetchOffsets(UcpEndpoint endpoint, ShuffleBlockId[] blockIds,
- long[] dataAddresses, RegisteredMemory offsetMemory) {
- DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(blockIds[0].shuffleId());
- for (int i = 0; i < blockIds.length; i++) {
- ShuffleBlockId blockId = blockIds[i];
-
- long offsetAddress = driverMetadata.offsetAddress(blockId.mapId());
- dataAddresses[i] = driverMetadata.dataAddress(blockId.mapId());
-
- offsetRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
- endpoint.unpackRemoteKey(driverMetadata.offsetRkey(blockId.mapId())));
-
- dataRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
- endpoint.unpackRemoteKey(driverMetadata.dataRkey(blockId.mapId())));
-
- endpoint.getNonBlockingImplicit(
- offsetAddress + blockId.reduceId() * UnsafeUtils.LONG_SIZE,
- offsetRkeysCache.get(blockId.mapId()),
- UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UnsafeUtils.LONG_SIZE),
- 2L * UnsafeUtils.LONG_SIZE);
+ private final MemoryPool mempool;
+ private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class);
+ private final UcxShuffleManager ucxShuffleManager;
+ private final TempShuffleReadMetrics shuffleReadMetrics;
+ private final UcxWorkerWrapper workerWrapper;
+ final HashMap offsetRkeysCache = new HashMap<>();
+ final HashMap dataRkeysCache = new HashMap<>();
+
+ public UcxShuffleClient(TempShuffleReadMetrics shuffleReadMetrics,
+ UcxWorkerWrapper workerWrapper) {
+ this.ucxShuffleManager = (UcxShuffleManager) SparkEnv.get().shuffleManager();
+ this.mempool = ucxShuffleManager.ucxNode().getMemoryPool();
+ this.shuffleReadMetrics = shuffleReadMetrics;
+ this.workerWrapper = workerWrapper;
+ }
+
+ /**
+ * Submits n non-blocking offset fetches to get the offsets needed for n blocks.
+ */
+ private void submitFetchOffsets(UcpEndpoint endpoint, ShuffleBlockId[] blockIds,
+ long[] dataAddresses, RegisteredMemory offsetMemory) {
+ DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(blockIds[0].shuffleId());
+ for (int i = 0; i < blockIds.length; i++) {
+ ShuffleBlockId blockId = blockIds[i];
+
+ long offsetAddress = driverMetadata.offsetAddress(blockId.mapId());
+ dataAddresses[i] = driverMetadata.dataAddress(blockId.mapId());
+
+ offsetRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
+ endpoint.unpackRemoteKey(driverMetadata.offsetRkey(blockId.mapId())));
+
+ dataRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
+ endpoint.unpackRemoteKey(driverMetadata.dataRkey(blockId.mapId())));
+
+ endpoint.getNonBlockingImplicit(
+ offsetAddress + blockId.reduceId() * UnsafeUtils.LONG_SIZE,
+ offsetRkeysCache.get(blockId.mapId()),
+ UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UnsafeUtils.LONG_SIZE),
+ 2L * UnsafeUtils.LONG_SIZE);
+ }
+ }
+
+ /**
+ * Reducer entry point. Fetches remote blocks, using 2 ucp_get calls.
+ * This method is inside ShuffleFetchIterator's for loop over hosts.
+ * First fetches block offset from index file, and then fetches block itself.
+ */
+ @Override
+ public void fetchBlocks(String host, int port, String execId,
+ String[] blockIds, BlockFetchingListener listener,
+ DownloadFileManager downloadFileManager) {
+ long startTime = System.currentTimeMillis();
+
+ BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty());
+ UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId);
+
+ long[] dataAddresses = new long[blockIds.length];
+
+ // Fetch 2 long offsets per block (the current block's and the next block's) to calculate the exact block size.
+ RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length);
+
+ ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds)
+ .map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new);
+
+ // Submits N implicit get requests without callback
+ submitFetchOffsets(endpoint, shuffleBlockIds, dataAddresses, offsetMemory);
+
+ // Flush guarantees that all of those requests have completed by the time the callback is called.
+ endpoint.flushNonBlocking(
+ new OnOffsetsFetchCallback(shuffleBlockIds, endpoint, listener, offsetMemory,
+ dataAddresses, dataRkeysCache));
+ shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime);
+ }
+
+ @Override
+ public void close() {
+ offsetRkeysCache.values().forEach(UcpRemoteKey::close);
+ dataRkeysCache.values().forEach(UcpRemoteKey::close);
+ logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime());
}
- }
-
- /**
- * Reducer entry point. Fetches remote blocks, using 2 ucp_get calls.
- * This method is inside ShuffleFetchIterator's for loop over hosts.
- * First fetches block offset from index file, and then fetches block itself.
- */
- @Override
- public void fetchBlocks(String host, int port, String execId,
- String[] blockIds, BlockFetchingListener listener,
- DownloadFileManager downloadFileManager) {
- long startTime = System.currentTimeMillis();
-
- BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty());
- UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId);
-
- long[] dataAddresses = new long[blockIds.length];
-
- // Need to fetch 2 long offsets current block + next block to calculate exact block size.
- RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length);
-
- ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds)
- .map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new);
-
- // Submits N implicit get requests without callback
- submitFetchOffsets(endpoint, shuffleBlockIds, dataAddresses, offsetMemory);
-
- // flush guarantees that all that requests completes when callback is called.
- // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush.
- workerWrapper.worker().flushNonBlocking(
- new OnOffsetsFetchCallback(shuffleBlockIds, endpoint, listener, offsetMemory,
- dataAddresses, dataRkeysCache));
- shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime);
- }
-
- @Override
- public void close() {
- offsetRkeysCache.values().forEach(UcpRemoteKey::close);
- dataRkeysCache.values().forEach(UcpRemoteKey::close);
- logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime());
- }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/OnOffsetsFetchCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/OnOffsetsFetchCallback.java
index 372955b5..e792e6db 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/OnOffsetsFetchCallback.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/OnOffsetsFetchCallback.java
@@ -25,69 +25,69 @@
* Callback, called when got all offsets for blocks
*/
public class OnOffsetsFetchCallback extends ReducerCallback {
- private final RegisteredMemory offsetMemory;
- private final long[] dataAddresses;
- private Map dataRkeysCache;
- private final Map mapId2PartitionId;
+ private final RegisteredMemory offsetMemory;
+ private final long[] dataAddresses;
+ private Map dataRkeysCache;
+ private final Map mapId2PartitionId;
- public OnOffsetsFetchCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener,
- RegisteredMemory offsetMemory, long[] dataAddresses,
- Map dataRkeysCache,
- Map mapId2PartitionId) {
- super(blockIds, endpoint, listener);
- this.offsetMemory = offsetMemory;
- this.dataAddresses = dataAddresses;
- this.dataRkeysCache = dataRkeysCache;
- this.mapId2PartitionId = mapId2PartitionId;
- }
+ public OnOffsetsFetchCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener,
+ RegisteredMemory offsetMemory, long[] dataAddresses,
+ Map dataRkeysCache,
+ Map mapId2PartitionId) {
+ super(blockIds, endpoint, listener);
+ this.offsetMemory = offsetMemory;
+ this.dataAddresses = dataAddresses;
+ this.dataRkeysCache = dataRkeysCache;
+ this.mapId2PartitionId = mapId2PartitionId;
+ }
- @Override
- public void onSuccess(UcpRequest request) {
- ByteBuffer resultOffset = offsetMemory.getBuffer();
- long totalSize = 0;
- int[] sizes = new int[blockIds.length];
- int offset = 0;
- long blockOffset;
- long blockLength;
- int offsetSize = UnsafeUtils.LONG_SIZE;
- for (int i = 0; i < blockIds.length; i++) {
- // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd |
- if (blockIds[i] instanceof ShuffleBlockBatchId) {
- ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId) blockIds[i];
- int blocksInBatch = blockBatchId.endReduceId() - blockBatchId.startReduceId();
- blockOffset = resultOffset.getLong(offset * 2 * offsetSize);
- blockLength = resultOffset.getLong(offset * 2 * offsetSize + offsetSize * blocksInBatch)
- - blockOffset;
- offset += blocksInBatch;
- } else {
- blockOffset = resultOffset.getLong(offset * 16);
- blockLength = resultOffset.getLong(offset * 16 + 8) - blockOffset;
- offset++;
- }
+ @Override
+ public void onSuccess(UcpRequest request) {
+ ByteBuffer resultOffset = offsetMemory.getBuffer();
+ long totalSize = 0;
+ int[] sizes = new int[blockIds.length];
+ int offset = 0;
+ long blockOffset;
+ long blockLength;
+ int offsetSize = UnsafeUtils.LONG_SIZE;
+ for (int i = 0; i < blockIds.length; i++) {
+ // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd |
+ if (blockIds[i] instanceof ShuffleBlockBatchId) {
+ ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId) blockIds[i];
+ int blocksInBatch = blockBatchId.endReduceId() - blockBatchId.startReduceId();
+ blockOffset = resultOffset.getLong(offset * 2 * offsetSize);
+ blockLength = resultOffset.getLong(offset * 2 * offsetSize + offsetSize * blocksInBatch)
+ - blockOffset;
+ offset += blocksInBatch;
+ } else {
+ blockOffset = resultOffset.getLong(offset * 16);
+ blockLength = resultOffset.getLong(offset * 16 + 8) - blockOffset;
+ offset++;
+ }
- assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
- sizes[i] = (int) blockLength;
- totalSize += blockLength;
- dataAddresses[i] += blockOffset;
- }
+ assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
+ sizes[i] = (int) blockLength;
+ totalSize += blockLength;
+ dataAddresses[i] += blockOffset;
+ }
- assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
- mempool.put(offsetMemory);
- RegisteredMemory blocksMemory = mempool.get((int) totalSize);
+ assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
+ mempool.put(offsetMemory);
+ RegisteredMemory blocksMemory = mempool.get((int) totalSize);
- offset = 0;
- // Submits N fetch blocks requests
- for (int i = 0; i < blockIds.length; i++) {
- int mapPartitionId = (blockIds[i] instanceof ShuffleBlockId) ?
- mapId2PartitionId.get(((ShuffleBlockId)blockIds[i]).mapId()) :
- mapId2PartitionId.get(((ShuffleBlockBatchId)blockIds[i]).mapId());
- endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(mapPartitionId),
- UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]);
- offset += sizes[i];
- }
+ offset = 0;
+ // Submit N block-fetch requests.
+ for (int i = 0; i < blockIds.length; i++) {
+ int mapPartitionId = (blockIds[i] instanceof ShuffleBlockId) ?
+ mapId2PartitionId.get(((ShuffleBlockId) blockIds[i]).mapId()) :
+ mapId2PartitionId.get(((ShuffleBlockBatchId) blockIds[i]).mapId());
+ endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(mapPartitionId),
+ UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]);
+ offset += sizes[i];
+ }
- // Process blocks when all fetched.
- // Flush guarantees that callback would invoke when all fetch requests will completed.
- endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes));
- }
+ // Process the blocks once all of them have been fetched.
+ // Flush guarantees that the callback is invoked only after all fetch requests have completed.
+ endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes));
+ }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/UcxShuffleClient.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/UcxShuffleClient.java
index 77a3d87c..c7d1360e 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/UcxShuffleClient.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/UcxShuffleClient.java
@@ -27,110 +27,109 @@
import java.util.Map;
public class UcxShuffleClient extends BlockStoreClient {
- private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class);
- private final UcxWorkerWrapper workerWrapper;
- private final Map mapId2PartitionId;
- private final TempShuffleReadMetrics shuffleReadMetrics;
- private final int shuffleId;
- final HashMap offsetRkeysCache = new HashMap<>();
- final HashMap dataRkeysCache = new HashMap<>();
-
-
- public UcxShuffleClient(int shuffleId, UcxWorkerWrapper workerWrapper,
- Map mapId2PartitionId, TempShuffleReadMetrics shuffleReadMetrics) {
- this.workerWrapper = workerWrapper;
- this.shuffleId = shuffleId;
- this.mapId2PartitionId = mapId2PartitionId;
- this.shuffleReadMetrics = shuffleReadMetrics;
- }
-
- /**
- * Submits n non blocking fetch offsets to get needed offsets for n blocks.
- */
- private void submitFetchOffsets(UcpEndpoint endpoint, BlockId[] blockIds,
- RegisteredMemory offsetMemory,
- long[] dataAddresses) {
- DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(shuffleId);
- long offset = 0;
- int startReduceId;
- long size;
-
- for (int i = 0; i < blockIds.length; i++) {
- BlockId blockId = blockIds[i];
- int mapIdpartition;
-
- if (blockId instanceof ShuffleBlockId) {
- ShuffleBlockId shuffleBlockId = (ShuffleBlockId) blockId;
- mapIdpartition = mapId2PartitionId.get(shuffleBlockId.mapId());
- size = 2L * UnsafeUtils.LONG_SIZE;
- startReduceId = shuffleBlockId.reduceId();
- } else {
- ShuffleBlockBatchId shuffleBlockBatchId = (ShuffleBlockBatchId) blockId;
- mapIdpartition = mapId2PartitionId.get(shuffleBlockBatchId.mapId());
- size = (shuffleBlockBatchId.endReduceId() - shuffleBlockBatchId.startReduceId())
- * 2L * UnsafeUtils.LONG_SIZE;
- startReduceId = shuffleBlockBatchId.startReduceId();
- }
-
- long offsetAddress = driverMetadata.offsetAddress(mapIdpartition);
- dataAddresses[i] = driverMetadata.dataAddress(mapIdpartition);
-
- offsetRkeysCache.computeIfAbsent(mapIdpartition, mapId ->
- endpoint.unpackRemoteKey(driverMetadata.offsetRkey(mapIdpartition)));
-
- dataRkeysCache.computeIfAbsent(mapIdpartition, mapId ->
- endpoint.unpackRemoteKey(driverMetadata.dataRkey(mapIdpartition)));
-
- endpoint.getNonBlockingImplicit(
- offsetAddress + startReduceId * UnsafeUtils.LONG_SIZE,
- offsetRkeysCache.get(mapIdpartition),
- UcxUtils.getAddress(offsetMemory.getBuffer()) + offset,
- size);
-
- offset += size;
+ private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class);
+ private final UcxWorkerWrapper workerWrapper;
+ private final Map mapId2PartitionId;
+ private final TempShuffleReadMetrics shuffleReadMetrics;
+ private final int shuffleId;
+ final HashMap offsetRkeysCache = new HashMap<>();
+ final HashMap dataRkeysCache = new HashMap<>();
+
+
+ public UcxShuffleClient(int shuffleId, UcxWorkerWrapper workerWrapper,
+ Map mapId2PartitionId, TempShuffleReadMetrics shuffleReadMetrics) {
+ this.workerWrapper = workerWrapper;
+ this.shuffleId = shuffleId;
+ this.mapId2PartitionId = mapId2PartitionId;
+ this.shuffleReadMetrics = shuffleReadMetrics;
}
- }
-
- @Override
- public void fetchBlocks(String host, int port, String execId, String[] blockIds, BlockFetchingListener listener,
- DownloadFileManager downloadFileManager) {
- long startTime = System.currentTimeMillis();
- BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty());
- UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId);
- long[] dataAddresses = new long[blockIds.length];
- int totalBlocks = 0;
-
- BlockId[] blocks = new BlockId[blockIds.length];
-
- for (int i = 0; i < blockIds.length; i++) {
- blocks[i] = BlockId.apply(blockIds[i]);
- if (blocks[i] instanceof ShuffleBlockId) {
- totalBlocks += 1;
- } else {
- ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId)blocks[i];
- totalBlocks += (blockBatchId.endReduceId() - blockBatchId.startReduceId());
- }
+
+ /**
+ * Submits n non-blocking offset fetches to obtain the offsets needed for n blocks.
+ */
+ private void submitFetchOffsets(UcpEndpoint endpoint, BlockId[] blockIds,
+ RegisteredMemory offsetMemory,
+ long[] dataAddresses) {
+ DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(shuffleId);
+ long offset = 0;
+ int startReduceId;
+ long size;
+
+ for (int i = 0; i < blockIds.length; i++) {
+ BlockId blockId = blockIds[i];
+ int mapIdpartition;
+
+ if (blockId instanceof ShuffleBlockId) {
+ ShuffleBlockId shuffleBlockId = (ShuffleBlockId) blockId;
+ mapIdpartition = mapId2PartitionId.get(shuffleBlockId.mapId());
+ size = 2L * UnsafeUtils.LONG_SIZE;
+ startReduceId = shuffleBlockId.reduceId();
+ } else {
+ ShuffleBlockBatchId shuffleBlockBatchId = (ShuffleBlockBatchId) blockId;
+ mapIdpartition = mapId2PartitionId.get(shuffleBlockBatchId.mapId());
+ size = (shuffleBlockBatchId.endReduceId() - shuffleBlockBatchId.startReduceId())
+ * 2L * UnsafeUtils.LONG_SIZE;
+ startReduceId = shuffleBlockBatchId.startReduceId();
+ }
+
+ long offsetAddress = driverMetadata.offsetAddress(mapIdpartition);
+ dataAddresses[i] = driverMetadata.dataAddress(mapIdpartition);
+
+ offsetRkeysCache.computeIfAbsent(mapIdpartition, mapId ->
+ endpoint.unpackRemoteKey(driverMetadata.offsetRkey(mapIdpartition)));
+
+ dataRkeysCache.computeIfAbsent(mapIdpartition, mapId ->
+ endpoint.unpackRemoteKey(driverMetadata.dataRkey(mapIdpartition)));
+
+ endpoint.getNonBlockingImplicit(
+ offsetAddress + startReduceId * UnsafeUtils.LONG_SIZE,
+ offsetRkeysCache.get(mapIdpartition),
+ UcxUtils.getAddress(offsetMemory.getBuffer()) + offset,
+ size);
+
+ offset += size;
+ }
+ }
+
+ @Override
+ public void fetchBlocks(String host, int port, String execId, String[] blockIds, BlockFetchingListener listener,
+ DownloadFileManager downloadFileManager) {
+ long startTime = System.currentTimeMillis();
+ BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty());
+ UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId);
+ long[] dataAddresses = new long[blockIds.length];
+ int totalBlocks = 0;
+
+ BlockId[] blocks = new BlockId[blockIds.length];
+
+ for (int i = 0; i < blockIds.length; i++) {
+ blocks[i] = BlockId.apply(blockIds[i]);
+ if (blocks[i] instanceof ShuffleBlockId) {
+ totalBlocks += 1;
+ } else {
+ ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId) blocks[i];
+ totalBlocks += (blockBatchId.endReduceId() - blockBatchId.startReduceId());
+ }
+ }
+
+ RegisteredMemory offsetMemory = ((UcxShuffleManager) SparkEnv.get().shuffleManager())
+ .ucxNode().getMemoryPool().get(totalBlocks * 2 * UnsafeUtils.LONG_SIZE);
+ // Submits N implicit get requests without callback
+ submitFetchOffsets(endpoint, blocks, offsetMemory, dataAddresses);
+
+ // Flush guarantees that all of those requests have completed by the time the callback is called.
+ endpoint.flushNonBlocking(
+ new OnOffsetsFetchCallback(blocks, endpoint, listener, offsetMemory,
+ dataAddresses, dataRkeysCache, mapId2PartitionId));
+
+ shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime);
}
- RegisteredMemory offsetMemory = ((UcxShuffleManager)SparkEnv.get().shuffleManager())
- .ucxNode().getMemoryPool().get(totalBlocks * 2 * UnsafeUtils.LONG_SIZE);
- // Submits N implicit get requests without callback
- submitFetchOffsets(endpoint, blocks, offsetMemory, dataAddresses);
-
- // flush guarantees that all that requests completes when callback is called.
- // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush.
- workerWrapper.worker().flushNonBlocking(
- new OnOffsetsFetchCallback(blocks, endpoint, listener, offsetMemory,
- dataAddresses, dataRkeysCache, mapId2PartitionId));
-
- shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime);
- }
-
- @Override
- public void close() {
- offsetRkeysCache.values().forEach(UcpRemoteKey::close);
- dataRkeysCache.values().forEach(UcpRemoteKey::close);
- logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime());
- }
+ @Override
+ public void close() {
+ offsetRkeysCache.values().forEach(UcpRemoteKey::close);
+ dataRkeysCache.values().forEach(UcpRemoteKey::close);
+ logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime());
+ }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/rpc/RpcConnectionCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/rpc/RpcConnectionCallback.java
index a81570e3..b0d885f2 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/rpc/RpcConnectionCallback.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/rpc/RpcConnectionCallback.java
@@ -28,72 +28,72 @@
* introduce cluster to connected executor.
*/
public class RpcConnectionCallback extends UcxCallback {
- private static final Logger logger = LoggerFactory.getLogger(RpcConnectionCallback.class);
- private final ByteBuffer metadataBuffer;
- private final boolean isDriver;
- private final UcxNode ucxNode;
- private static final ConcurrentMap rpcConnections =
- UcxNode.getRpcConnections();
- private static final ConcurrentMap workerAdresses =
- UcxNode.getWorkerAddresses();
+ private static final Logger logger = LoggerFactory.getLogger(RpcConnectionCallback.class);
+ private final ByteBuffer metadataBuffer;
+ private final boolean isDriver;
+ private final UcxNode ucxNode;
+ private static final ConcurrentMap rpcConnections =
+ UcxNode.getRpcConnections();
+ private static final ConcurrentMap workerAdresses =
+ UcxNode.getWorkerAddresses();
- RpcConnectionCallback(ByteBuffer metadataBuffer, boolean isDriver, UcxNode ucxNode) {
- this.metadataBuffer = metadataBuffer;
- this.isDriver = isDriver;
- this.ucxNode = ucxNode;
- }
+ RpcConnectionCallback(ByteBuffer metadataBuffer, boolean isDriver, UcxNode ucxNode) {
+ this.metadataBuffer = metadataBuffer;
+ this.isDriver = isDriver;
+ this.ucxNode = ucxNode;
+ }
- @Override
- public void onSuccess(UcpRequest request) {
- int workerAddressSize = metadataBuffer.getInt();
- ByteBuffer workerAddress = Platform.allocateDirectBuffer(workerAddressSize);
+ @Override
+ public void onSuccess(UcpRequest request) {
+ int workerAddressSize = metadataBuffer.getInt();
+ ByteBuffer workerAddress = Platform.allocateDirectBuffer(workerAddressSize);
- // Copy worker address from metadata buffer to separate buffer.
- final ByteBuffer metadataView = metadataBuffer.duplicate();
- metadataView.limit(metadataView.position() + workerAddressSize);
- workerAddress.put(metadataView);
- metadataBuffer.position(metadataBuffer.position() + workerAddressSize);
+ // Copy worker address from metadata buffer to separate buffer.
+ final ByteBuffer metadataView = metadataBuffer.duplicate();
+ metadataView.limit(metadataView.position() + workerAddressSize);
+ workerAddress.put(metadataView);
+ metadataBuffer.position(metadataBuffer.position() + workerAddressSize);
- BlockManagerId blockManagerId;
- try {
- blockManagerId = SerializableBlockManagerID
- .deserializeBlockManagerID(metadataBuffer);
- } catch (IOException e) {
- String errorMsg = String.format("Failed to deserialize BlockManagerId: %s", e.getMessage());
- throw new UcxException(errorMsg);
- }
- logger.debug("Received RPC message from {}", blockManagerId);
- UcpWorker globalWorker = ucxNode.getGlobalWorker();
+ BlockManagerId blockManagerId;
+ try {
+ blockManagerId = SerializableBlockManagerID
+ .deserializeBlockManagerID(metadataBuffer);
+ } catch (IOException e) {
+ String errorMsg = String.format("Failed to deserialize BlockManagerId: %s", e.getMessage());
+ throw new UcxException(errorMsg);
+ }
+ logger.debug("Received RPC message from {}", blockManagerId);
+ UcpWorker globalWorker = ucxNode.getGlobalWorker();
- workerAddress.clear();
+ workerAddress.clear();
- if (isDriver) {
- metadataBuffer.clear();
- UcpEndpoint newConnection = globalWorker.newEndpoint(
- new UcpEndpointParams().setPeerErrorHandlingMode()
- .setUcpAddress(workerAddress));
- // For each existing connection
- rpcConnections.forEach((connection, connectionMetadata) -> {
- // send address of joined worker to already connected workers
- connection.sendTaggedNonBlocking(metadataBuffer, null);
- // introduce other workers to joined worker
- newConnection.sendTaggedNonBlocking(connectionMetadata, null);
- });
+ if (isDriver) {
+ metadataBuffer.clear();
+ UcpEndpoint newConnection = globalWorker.newEndpoint(
+ new UcpEndpointParams().setPeerErrorHandlingMode()
+ .setUcpAddress(workerAddress));
+ // For each existing connection
+ rpcConnections.forEach((connection, connectionMetadata) -> {
+ // send address of joined worker to already connected workers
+ connection.sendTaggedNonBlocking(metadataBuffer, null);
+ // introduce other workers to joined worker
+ newConnection.sendTaggedNonBlocking(connectionMetadata, null);
+ });
- rpcConnections.put(newConnection, metadataBuffer);
- }
- workerAdresses.put(blockManagerId, workerAddress);
- synchronized (workerAdresses) {
- workerAdresses.notifyAll();
+ rpcConnections.put(newConnection, metadataBuffer);
+ }
+ workerAdresses.put(blockManagerId, workerAddress);
+ synchronized (workerAdresses) {
+ workerAdresses.notifyAll();
+ }
}
- }
- @Override
- public void onError(int ucsStatus, String errorMsg) {
- // UCS_ERR_CANCELED = -16,
- if (ucsStatus != -16) {
- logger.error("Request error: {}", errorMsg);
- throw new UcxException(errorMsg);
+ @Override
+ public void onError(int ucsStatus, String errorMsg) {
+ // UCS_ERR_CANCELED = -16,
+ if (ucsStatus != -16) {
+ logger.error("Request error: {}", errorMsg);
+ throw new UcxException(errorMsg);
+ }
}
- }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/rpc/SerializableBlockManagerID.java b/src/main/java/org/apache/spark/shuffle/ucx/rpc/SerializableBlockManagerID.java
index dfc5af7a..5560e418 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/rpc/SerializableBlockManagerID.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/rpc/SerializableBlockManagerID.java
@@ -14,19 +14,19 @@
*/
public class SerializableBlockManagerID {
- public static void serializeBlockManagerID(BlockManagerId blockManagerId,
- ByteBuffer metadataBuffer) throws IOException {
- ObjectOutputStream oos = new ObjectOutputStream(
- new ByteBufferOutputStream(metadataBuffer));
- blockManagerId.writeExternal(oos);
- oos.close();
- }
+ public static void serializeBlockManagerID(BlockManagerId blockManagerId,
+ ByteBuffer metadataBuffer) throws IOException {
+ ObjectOutputStream oos = new ObjectOutputStream(
+ new ByteBufferOutputStream(metadataBuffer));
+ blockManagerId.writeExternal(oos);
+ oos.close();
+ }
- static BlockManagerId deserializeBlockManagerID(ByteBuffer metadataBuffer) throws IOException {
- ObjectInputStream ois =
- new ObjectInputStream(new ByteBufferInputStream(metadataBuffer));
- BlockManagerId blockManagerId = BlockManagerId.apply(ois);
- ois.close();
- return blockManagerId;
- }
+ static BlockManagerId deserializeBlockManagerID(ByteBuffer metadataBuffer) throws IOException {
+ ObjectInputStream ois =
+ new ObjectInputStream(new ByteBufferInputStream(metadataBuffer));
+ BlockManagerId blockManagerId = BlockManagerId.apply(ois);
+ ois.close();
+ return blockManagerId;
+ }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxListenerThread.java b/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxListenerThread.java
index ce2ecff5..e2252d06 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxListenerThread.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxListenerThread.java
@@ -17,47 +17,47 @@
* Thread for progressing global worker for connection establishment and RPC exchange.
*/
public class UcxListenerThread extends Thread implements Runnable {
- private static final Logger logger = LoggerFactory.getLogger(UcxListenerThread.class);
- private final UcxNode ucxNode;
- private final boolean isDriver;
- private final UcpWorker globalWorker;
+ private static final Logger logger = LoggerFactory.getLogger(UcxListenerThread.class);
+ private final UcxNode ucxNode;
+ private final boolean isDriver;
+ private final UcpWorker globalWorker;
- public UcxListenerThread(UcxNode ucxNode, boolean isDriver) {
- this.ucxNode = ucxNode;
- this.isDriver = isDriver;
- this.globalWorker = ucxNode.getGlobalWorker();
- setDaemon(true);
- setName("UcxListenerThread");
- }
+ public UcxListenerThread(UcxNode ucxNode, boolean isDriver) {
+ this.ucxNode = ucxNode;
+ this.isDriver = isDriver;
+ this.globalWorker = ucxNode.getGlobalWorker();
+ setDaemon(true);
+ setName("UcxListenerThread");
+ }
- /**
- * 2. Both Driver and Executor. Accept Recv request.
- * If on driver broadcast it to other executors. On executor just save worker addresses.
- */
- private UcpRequest recvRequest() {
- ByteBuffer metadataBuffer = Platform.allocateDirectBuffer(
- ucxNode.getConf().metadataRPCBufferSize());
- RpcConnectionCallback callback = new RpcConnectionCallback(metadataBuffer, isDriver, ucxNode);
- return globalWorker.recvTaggedNonBlocking(metadataBuffer, callback);
- }
+ /**
+ * 2. Both Driver and Executor. Accept Recv request.
+ * If on the driver, broadcast it to the other executors. On an executor, just save the worker addresses.
+ */
+ private UcpRequest recvRequest() {
+ ByteBuffer metadataBuffer = Platform.allocateDirectBuffer(
+ ucxNode.getConf().metadataRPCBufferSize());
+ RpcConnectionCallback callback = new RpcConnectionCallback(metadataBuffer, isDriver, ucxNode);
+ return globalWorker.recvTaggedNonBlocking(metadataBuffer, callback);
+ }
- @Override
- public void run() {
- UcpRequest recv = recvRequest();
- while (!isInterrupted()) {
- if (recv.isCompleted()) {
- // Process 1 recv request at a time.
- recv = recvRequest();
- }
- try {
- if (globalWorker.progress() == 0) {
- globalWorker.waitForEvents();
+ @Override
+ public void run() {
+ UcpRequest recv = recvRequest();
+ while (!isInterrupted()) {
+ if (recv.isCompleted()) {
+ // Process 1 recv request at a time.
+ recv = recvRequest();
+ }
+ try {
+ if (globalWorker.progress() == 0) {
+ globalWorker.waitForEvents();
+ }
+ } catch (Exception e) {
+ logger.error(e.getLocalizedMessage());
+ interrupt();
+ }
}
- } catch (Exception e) {
- logger.error(e.getLocalizedMessage());
- interrupt();
- }
+ globalWorker.cancelRequest(recv);
}
- globalWorker.cancelRequest(recv);
- }
}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxRemoteMemory.java b/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxRemoteMemory.java
index 08626cee..e48b26f2 100755
--- a/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxRemoteMemory.java
+++ b/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxRemoteMemory.java
@@ -16,39 +16,40 @@
* spark's mechanism to broadcast tasks.
*/
public class UcxRemoteMemory implements Serializable {
- private long address;
- private ByteBuffer rkeyBuffer;
-
- public UcxRemoteMemory(long address, ByteBuffer rkeyBuffer) {
- this.address = address;
- this.rkeyBuffer = rkeyBuffer;
- }
-
- public UcxRemoteMemory() {}
-
- private void writeObject(ObjectOutputStream out) throws IOException {
- out.writeLong(address);
- out.writeInt(rkeyBuffer.limit());
- byte[] copy = new byte[rkeyBuffer.limit()];
- rkeyBuffer.clear();
- rkeyBuffer.get(copy);
- out.write(copy);
- }
-
- private void readObject(ObjectInputStream in) throws IOException {
- this.address = in.readLong();
- int bufferSize = in.readInt();
- byte[] buffer = new byte[bufferSize];
- in.read(buffer, 0, bufferSize);
- this.rkeyBuffer = ByteBuffer.allocateDirect(bufferSize).put(buffer);
- this.rkeyBuffer.clear();
- }
-
- public long getAddress() {
- return address;
- }
-
- public ByteBuffer getRkeyBuffer() {
- return rkeyBuffer;
- }
+ private long address;
+ private ByteBuffer rkeyBuffer;
+
+ public UcxRemoteMemory(long address, ByteBuffer rkeyBuffer) {
+ this.address = address;
+ this.rkeyBuffer = rkeyBuffer;
+ }
+
+ public UcxRemoteMemory() {
+ }
+
+ private void writeObject(ObjectOutputStream out) throws IOException {
+ out.writeLong(address);
+ out.writeInt(rkeyBuffer.limit());
+ byte[] copy = new byte[rkeyBuffer.limit()];
+ rkeyBuffer.clear();
+ rkeyBuffer.get(copy);
+ out.write(copy);
+ }
+
+ private void readObject(ObjectInputStream in) throws IOException {
+ this.address = in.readLong();
+ int bufferSize = in.readInt();
+ byte[] buffer = new byte[bufferSize];
+ in.read(buffer, 0, bufferSize);
+ this.rkeyBuffer = ByteBuffer.allocateDirect(bufferSize).put(buffer);
+ this.rkeyBuffer.clear();
+ }
+
+ public long getAddress() {
+ return address;
+ }
+
+ public ByteBuffer getRkeyBuffer() {
+ return rkeyBuffer;
+ }
}
diff --git a/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleBlockResolver.scala
index 1405fade..e4dd335f 100755
--- a/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleBlockResolver.scala
+++ b/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleBlockResolver.scala
@@ -1,7 +1,7 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle
import java.io.{File, RandomAccessFile}
@@ -14,25 +14,31 @@ import org.openucx.jucx.ucp.{UcpMemMapParams, UcpMemory}
import org.apache.spark.shuffle.ucx.UnsafeUtils
import org.apache.spark.SparkException
-/**
- * Mapper entry point for UcxShuffle plugin. Performs memory registration
- * of data and index files and publish addresses to driver metadata buffer.
- */
-abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
- extends IndexShuffleBlockResolver(ucxShuffleManager.conf) {
+/** Mapper entry point for UcxShuffle plugin. Performs memory registration
+ * of data and index files and publishes their addresses to the driver metadata buffer.
+ */
+abstract class CommonUcxShuffleBlockResolver(
+ ucxShuffleManager: CommonUcxShuffleManager
+) extends IndexShuffleBlockResolver(ucxShuffleManager.conf) {
private lazy val memPool = ucxShuffleManager.ucxNode.getMemoryPool
// Keep track of registered memory regions to release them when shuffle not needed
- private val fileMappings = new ConcurrentHashMap[ShuffleId, CopyOnWriteArrayList[UcpMemory]].asScala
- private val offsetMappings = new ConcurrentHashMap[ShuffleId, CopyOnWriteArrayList[UcpMemory]].asScala
-
- /**
- * Mapper commit protocol extension. Register index and data files and publish all needed
- * metadata to driver.
- */
- def writeIndexFileAndCommitCommon(shuffleId: ShuffleId, mapId: Int,
- lengths: Array[Long], dataTmp: File,
- indexBackFile: RandomAccessFile, dataBackFile: RandomAccessFile): Unit = {
+ private val fileMappings =
+ new ConcurrentHashMap[ShuffleId, CopyOnWriteArrayList[UcpMemory]].asScala
+ private val offsetMappings =
+ new ConcurrentHashMap[ShuffleId, CopyOnWriteArrayList[UcpMemory]].asScala
+
+ /** Mapper commit protocol extension. Register index and data files and publish all needed
+ * metadata to driver.
+ */
+ def writeIndexFileAndCommitCommon(
+ shuffleId: ShuffleId,
+ mapId: Int,
+ lengths: Array[Long],
+ dataTmp: File,
+ indexBackFile: RandomAccessFile,
+ dataBackFile: RandomAccessFile
+ ): Unit = {
val startTime = System.currentTimeMillis()
fileMappings.putIfAbsent(shuffleId, new CopyOnWriteArrayList[UcpMemory]())
@@ -42,19 +48,26 @@ abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffle
val dataFileChannel = dataBackFile.getChannel
// Memory map and register data and index file.
- val dataAddress = UnsafeUtils.mmap(dataFileChannel, 0, dataBackFile.length())
- val memMapParams = new UcpMemMapParams().setAddress(dataAddress)
+ val dataAddress =
+ UnsafeUtils.mmap(dataFileChannel, 0, dataBackFile.length())
+ val memMapParams = new UcpMemMapParams()
+ .setAddress(dataAddress)
.setLength(dataBackFile.length())
if (ucxShuffleManager.ucxShuffleConf.useOdp) {
memMapParams.nonBlocking()
}
- val dataMemory = ucxShuffleManager.ucxNode.getContext.memoryMap(memMapParams)
+ val dataMemory =
+ ucxShuffleManager.ucxNode.getContext.memoryMap(memMapParams)
fileMappings(shuffleId).add(dataMemory)
- assume(indexBackFile.length() == UnsafeUtils.LONG_SIZE * (lengths.length + 1))
+ assume(
+ indexBackFile.length() == UnsafeUtils.LONG_SIZE * (lengths.length + 1)
+ )
- val offsetAddress = UnsafeUtils.mmap(indexFileChannel, 0, indexBackFile.length())
+ val offsetAddress =
+ UnsafeUtils.mmap(indexFileChannel, 0, indexBackFile.length())
memMapParams.setAddress(offsetAddress).setLength(indexBackFile.length())
- val offsetMemory = ucxShuffleManager.ucxNode.getContext.memoryMap(memMapParams)
+ val offsetMemory =
+ ucxShuffleManager.ucxNode.getContext.memoryMap(memMapParams)
offsetMappings(shuffleId).add(offsetMemory)
dataFileChannel.close()
@@ -65,14 +78,19 @@ abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffle
val fileMemoryRkey = dataMemory.getRemoteKeyBuffer
val offsetRkey = offsetMemory.getRemoteKeyBuffer
- val metadataRegisteredMemory = memPool.get(
- fileMemoryRkey.capacity() + offsetRkey.capacity() + 24)
+ val metadataRegisteredMemory =
+ memPool.get(fileMemoryRkey.capacity() + offsetRkey.capacity() + 24)
val metadataBuffer = metadataRegisteredMemory.getBuffer.slice()
- if (metadataBuffer.remaining() > ucxShuffleManager.ucxShuffleConf.metadataBlockSize) {
- throw new SparkException(s"Metadata block size ${metadataBuffer.remaining() / 2} " +
- s"is greater then configured ${ucxShuffleManager.ucxShuffleConf.RKEY_SIZE.key}" +
- s"(${ucxShuffleManager.ucxShuffleConf.metadataBlockSize}).")
+ if (
+ metadataBuffer
+ .remaining() > ucxShuffleManager.ucxShuffleConf.metadataBlockSize
+ ) {
+ throw new SparkException(
+ s"Metadata block size ${metadataBuffer.remaining() / 2} " +
+ s"is greater then configured ${ucxShuffleManager.ucxShuffleConf.RKEY_SIZE.key}" +
+ s"(${ucxShuffleManager.ucxShuffleConf.metadataBlockSize})."
+ )
}
metadataBuffer.clear()
@@ -94,16 +112,23 @@ abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffle
mapId * ucxShuffleManager.ucxShuffleConf.metadataBlockSize
val driverEndpoint = workerWrapper.driverEndpoint
- val request = driverEndpoint.putNonBlocking(UcxUtils.getAddress(metadataBuffer),
- metadataBuffer.remaining(), driverOffset, driverMetadata.driverRkey, null)
+ val request = driverEndpoint.putNonBlocking(
+ UcxUtils.getAddress(metadataBuffer),
+ metadataBuffer.remaining(),
+ driverOffset,
+ driverMetadata.driverRkey,
+ null
+ )
workerWrapper.preconnect()
// Blocking progress needed to make sure last mapper published data to driver before
// reducer starts.
workerWrapper.waitRequest(request)
memPool.put(metadataRegisteredMemory)
- logInfo(s"MapID: $mapId register files + publishing overhead: " +
- s"${System.currentTimeMillis() - startTime} ms")
+ logInfo(
+ s"MapID: $mapId register files + publishing overhead: " +
+ s"${System.currentTimeMillis() - startTime} ms"
+ )
}
private def unregisterAndUnmap(mem: UcpMemory): Unit = {
@@ -114,10 +139,16 @@ abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffle
}
def removeShuffle(shuffleId: Int): Unit = {
- fileMappings.remove(shuffleId).foreach((mappings: CopyOnWriteArrayList[UcpMemory]) =>
- mappings.asScala.par.foreach(unregisterAndUnmap))
- offsetMappings.remove(shuffleId).foreach((mappings: CopyOnWriteArrayList[UcpMemory]) =>
- mappings.asScala.par.foreach(unregisterAndUnmap))
+ fileMappings
+ .remove(shuffleId)
+ .foreach((mappings: CopyOnWriteArrayList[UcpMemory]) =>
+ mappings.asScala.par.foreach(unregisterAndUnmap)
+ )
+ offsetMappings
+ .remove(shuffleId)
+ .foreach((mappings: CopyOnWriteArrayList[UcpMemory]) =>
+ mappings.asScala.par.foreach(unregisterAndUnmap)
+ )
}
override def stop(): Unit = {
diff --git a/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleManager.scala
index 05ebad4c..e2a00a77 100755
--- a/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleManager.scala
+++ b/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleManager.scala
@@ -1,7 +1,7 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle
import java.util.concurrent.ConcurrentHashMap
@@ -16,10 +16,10 @@ import org.apache.spark.shuffle.ucx.UcxNode
import org.apache.spark.shuffle.ucx.rpc.UcxRemoteMemory
import org.apache.spark.unsafe.Platform
-/**
- * Common part for all spark versions for UcxShuffleManager logic
- */
-abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) extends SortShuffleManager(conf) {
+/** Common part for all spark versions for UcxShuffleManager logic
+ */
+abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean)
+ extends SortShuffleManager(conf) {
type ShuffleId = Int
type MapId = Int
val ucxShuffleConf = new UcxShuffleConf(conf)
@@ -36,9 +36,11 @@ abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) e
startUcxNodeIfMissing()
}
- protected def registerShuffleCommon[K, V, C](baseHandle: BaseShuffleHandle[K,V,C],
- shuffleId: ShuffleId,
- numMaps: Int): ShuffleHandle = {
+ protected def registerShuffleCommon[K, V, C](
+ baseHandle: BaseShuffleHandle[K, V, C],
+ shuffleId: ShuffleId,
+ numMaps: Int
+ ): ShuffleHandle = {
// Register metadata buffer where each map will publish it's index and data file metadata
val metadataBufferSize = numMaps * ucxShuffleConf.metadataBlockSize
val metadataBuffer = Platform.allocateDirectBuffer(metadataBufferSize.toInt)
@@ -46,24 +48,25 @@ abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) e
val metadataMemory = ucxNode.getContext.registerMemory(metadataBuffer)
shuffleIdToMetadataBuffer.put(shuffleId, metadataMemory)
- val driverMemory = new UcxRemoteMemory(metadataMemory.getAddress,
- metadataMemory.getRemoteKeyBuffer)
+ val driverMemory = new UcxRemoteMemory(
+ metadataMemory.getAddress,
+ metadataMemory.getRemoteKeyBuffer
+ )
- val handle = new UcxShuffleHandle(shuffleId, driverMemory, numMaps, baseHandle)
+ val handle =
+ new UcxShuffleHandle(shuffleId, driverMemory, numMaps, baseHandle)
shuffleIdToHandle.putIfAbsent(shuffleId, handle)
handle
}
- /**
- * Mapping between shuffle and metadata buffer, to deregister it when shuffle not needed.
- */
+ /** Mapping between a shuffle and its metadata buffer, used to deregister the buffer when the shuffle is no longer needed.
+ */
protected val shuffleIdToMetadataBuffer: mutable.Map[ShuffleId, UcpMemory] =
new ConcurrentHashMap[ShuffleId, UcpMemory]().asScala
- /**
- * Atomically starts UcxNode singleton - one for all shuffle threads.
- */
+ /** Atomically starts UcxNode singleton - one for all shuffle threads.
+ */
def startUcxNodeIfMissing(): Unit = if (ucxNode == null) {
synchronized {
if (ucxNode == null) {
@@ -74,13 +77,14 @@ abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) e
override def unregisterShuffle(shuffleId: Int): Boolean = {
shuffleIdToMetadataBuffer.remove(shuffleId).foreach(_.deregister())
- shuffleBlockResolver.asInstanceOf[CommonUcxShuffleBlockResolver].removeShuffle(shuffleId)
+ shuffleBlockResolver
+ .asInstanceOf[CommonUcxShuffleBlockResolver]
+ .removeShuffle(shuffleId)
super.unregisterShuffle(shuffleId)
}
- /**
- * Called on both driver and executors to finally cleanup resources.
- */
+ /** Called on both driver and executors to finally cleanup resources.
+ */
override def stop(): Unit = synchronized {
logInfo("Stopping shuffle manager")
shuffleIdToHandle.keys.foreach(unregisterShuffle)
@@ -94,11 +98,12 @@ abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) e
}
-/**
- * Spark shuffle handles extensions, broadcasted by TCP to executors.
- * Added metadataBufferOnDriver field, that contains address and rkey of driver metadata buffer.
- */
-class UcxShuffleHandle[K, V, C](override val shuffleId: Int,
- val metadataBufferOnDriver: UcxRemoteMemory,
- val numMaps: Int,
- val baseHandle: BaseShuffleHandle[K,V,C]) extends ShuffleHandle(shuffleId)
+/** Spark shuffle handle extension, broadcast over TCP to executors.
+ * Adds the metadataBufferOnDriver field, which contains the address and rkey of the driver metadata buffer.
+ */
+class UcxShuffleHandle[K, V, C](
+ override val shuffleId: Int,
+ val metadataBufferOnDriver: UcxRemoteMemory,
+ val numMaps: Int,
+ val baseHandle: BaseShuffleHandle[K, V, C]
+) extends ShuffleHandle(shuffleId)
diff --git a/src/main/scala/org/apache/spark/shuffle/UcxShuffleConf.scala b/src/main/scala/org/apache/spark/shuffle/UcxShuffleConf.scala
index 2fc6c6c2..8b7bcb9e 100755
--- a/src/main/scala/org/apache/spark/shuffle/UcxShuffleConf.scala
+++ b/src/main/scala/org/apache/spark/shuffle/UcxShuffleConf.scala
@@ -1,7 +1,7 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle
import scala.collection.JavaConverters._
@@ -11,33 +11,34 @@ import org.apache.spark.internal.config.{ConfigBuilder, ConfigEntry}
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.util.Utils
-/**
- * Plugin configuration properties.
- */
+/** Plugin configuration properties.
+ */
class UcxShuffleConf(conf: SparkConf) extends SparkConf {
private def getUcxConf(name: String) = s"spark.shuffle.ucx.$name"
lazy val getNumProcesses: Int = getInt("spark.executor.instances", 1)
- lazy val coresPerProcess: Int = getInt("spark.executor.cores",
- Runtime.getRuntime.availableProcessors())
+ lazy val coresPerProcess: Int =
+ getInt("spark.executor.cores", Runtime.getRuntime.availableProcessors())
- lazy val driverHost: String = conf.get(getUcxConf("driver.host"),
- conf.get("spark.driver.host", "0.0.0.0"))
+ lazy val driverHost: String = conf.get(
+ getUcxConf("driver.host"),
+ conf.get("spark.driver.host", "0.0.0.0")
+ )
lazy val driverPort: Int = conf.getInt(getUcxConf("driver.port"), 55443)
// Metadata
lazy val RKEY_SIZE: ConfigEntry[Long] =
- ConfigBuilder(getUcxConf("rkeySize"))
- .doc("Maximum size of rKeyBuffer")
- .bytesConf(ByteUnit.BYTE)
- .createWithDefault(150)
+ ConfigBuilder(getUcxConf("rkeySize"))
+ .doc("Maximum size of rKeyBuffer")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(150)
// For metadata we publish index file + data file rkeys
- lazy val metadataBlockSize: Long = 2 * conf.getSizeAsBytes(RKEY_SIZE.key,
- RKEY_SIZE.defaultValueString)
+ lazy val metadataBlockSize: Long =
+ 2 * conf.getSizeAsBytes(RKEY_SIZE.key, RKEY_SIZE.defaultValueString)
private lazy val METADATA_RPC_BUFFER_SIZE =
ConfigBuilder(getUcxConf("rpc.metadata.bufferSize"))
@@ -45,46 +46,71 @@ class UcxShuffleConf(conf: SparkConf) extends SparkConf {
.bytesConf(ByteUnit.BYTE)
.createWithDefault(4096)
- lazy val metadataRPCBufferSize: Int = conf.getSizeAsBytes(METADATA_RPC_BUFFER_SIZE.key,
- METADATA_RPC_BUFFER_SIZE.defaultValueString).toInt
+ lazy val metadataRPCBufferSize: Int = conf
+ .getSizeAsBytes(
+ METADATA_RPC_BUFFER_SIZE.key,
+ METADATA_RPC_BUFFER_SIZE.defaultValueString
+ )
+ .toInt
// Memory Pool
private lazy val PREALLOCATE_BUFFERS =
- ConfigBuilder(getUcxConf("memory.preAllocateBuffers"))
- .doc("Comma separated list of buffer size : buffer count pairs to preallocate in memory pool. E.g. 4k:1000,16k:500")
- .stringConf.createWithDefault("")
-
- lazy val preallocateBuffersMap: java.util.Map[java.lang.Integer, java.lang.Integer] = {
- conf.get(PREALLOCATE_BUFFERS).split(",").withFilter(s => !s.isEmpty)
- .map(entry => entry.split(":") match {
- case Array(bufferSize, bufferCount) =>
- (int2Integer(Utils.byteStringAsBytes(bufferSize.trim).toInt),
- int2Integer(bufferCount.toInt))
- }).toMap.asJava
+ ConfigBuilder(getUcxConf("memory.preAllocateBuffers"))
+ .doc(
+ "Comma separated list of buffer size : buffer count pairs to preallocate in memory pool. E.g. 4k:1000,16k:500"
+ )
+ .stringConf
+ .createWithDefault("")
+
+ lazy val preallocateBuffersMap
+ : java.util.Map[java.lang.Integer, java.lang.Integer] = {
+ conf
+ .get(PREALLOCATE_BUFFERS)
+ .split(",")
+ .withFilter(s => !s.isEmpty)
+ .map(entry =>
+ entry.split(":") match {
+ case Array(bufferSize, bufferCount) =>
+ (
+ int2Integer(Utils.byteStringAsBytes(bufferSize.trim).toInt),
+ int2Integer(bufferCount.toInt)
+ )
+ }
+ )
+ .toMap
+ .asJava
}
- private lazy val MIN_BUFFER_SIZE = ConfigBuilder(getUcxConf("memory.minBufferSize"))
- .doc("Minimal buffer size in memory pool.")
- .bytesConf(ByteUnit.BYTE)
- .createWithDefault(1024)
+ private lazy val MIN_BUFFER_SIZE =
+ ConfigBuilder(getUcxConf("memory.minBufferSize"))
+ .doc("Minimal buffer size in memory pool.")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(1024)
- lazy val minBufferSize: Long = conf.getSizeAsBytes(MIN_BUFFER_SIZE.key,
- MIN_BUFFER_SIZE.defaultValueString)
+ lazy val minBufferSize: Long =
+ conf.getSizeAsBytes(MIN_BUFFER_SIZE.key, MIN_BUFFER_SIZE.defaultValueString)
private lazy val MIN_REGISTRATION_SIZE =
ConfigBuilder(getUcxConf("memory.minAllocationSize"))
- .doc("Minimal memory registration size in memory pool.")
- .bytesConf(ByteUnit.MiB)
- .createWithDefault(4)
-
- lazy val minRegistrationSize: Int = conf.getSizeAsBytes(MIN_REGISTRATION_SIZE.key,
- MIN_REGISTRATION_SIZE.defaultValueString).toInt
-
- private lazy val PREREGISTER_MEMORY = ConfigBuilder(getUcxConf("memory.preregister"))
- .doc("Whether to do ucp mem map for allocated memory in memory pool")
- .booleanConf.createWithDefault(true)
-
- lazy val preregisterMemory: Boolean = conf.getBoolean(PREREGISTER_MEMORY.key, PREREGISTER_MEMORY.defaultValue.get)
+ .doc("Minimal memory registration size in memory pool.")
+ .bytesConf(ByteUnit.MiB)
+ .createWithDefault(4)
+
+ lazy val minRegistrationSize: Int = conf
+ .getSizeAsBytes(
+ MIN_REGISTRATION_SIZE.key,
+ MIN_REGISTRATION_SIZE.defaultValueString
+ )
+ .toInt
+
+ private lazy val PREREGISTER_MEMORY =
+ ConfigBuilder(getUcxConf("memory.preregister"))
+ .doc("Whether to do ucp mem map for allocated memory in memory pool")
+ .booleanConf
+ .createWithDefault(true)
+
+ lazy val preregisterMemory: Boolean =
+ conf.getBoolean(PREREGISTER_MEMORY.key, PREREGISTER_MEMORY.defaultValue.get)
lazy val useOdp: Boolean = conf.getBoolean(getUcxConf("memory.useOdp"), false)
}
diff --git a/src/main/scala/org/apache/spark/shuffle/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/UcxWorkerWrapper.scala
index 0986a500..ab4ce95f 100755
--- a/src/main/scala/org/apache/spark/shuffle/UcxWorkerWrapper.scala
+++ b/src/main/scala/org/apache/spark/shuffle/UcxWorkerWrapper.scala
@@ -1,7 +1,7 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle
import java.io.Closeable
@@ -13,19 +13,28 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import org.openucx.jucx.UcxException
-import org.openucx.jucx.ucp.{UcpEndpoint, UcpEndpointParams, UcpRemoteKey, UcpRequest, UcpWorker}
+import org.openucx.jucx.ucp.{
+ UcpEndpoint,
+ UcpEndpointParams,
+ UcpRemoteKey,
+ UcpRequest,
+ UcpWorker
+}
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.ucx.{UcxNode, UnsafeUtils}
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.unsafe.Platform
-/**
- * Driver metadata buffer information that holds unpacked RkeyBuffer for this WorkerWrapper
- * and fetched buffer itself.
- */
-case class DriverMetadata(address: Long, driverRkey: UcpRemoteKey, length: Int,
- var data: ByteBuffer) {
+/** Driver metadata buffer information that holds unpacked RkeyBuffer for this WorkerWrapper
+ * and fetched buffer itself.
+ */
+case class DriverMetadata(
+ address: Long,
+ driverRkey: UcpRemoteKey,
+ length: Int,
+ var data: ByteBuffer
+) {
// Driver metadata is an array of blocks:
// | mapId0 | mapId1 | mapId2 | mapId3 | mapId4 | mapId5 |
// Each block in driver metadata has next layout:
@@ -64,64 +73,68 @@ case class DriverMetadata(address: Long, driverRkey: UcpRemoteKey, length: Int,
}
}
-/**
- * Worker per thread wrapper, that maintains connection and progress logic.
- */
-class UcxWorkerWrapper(val worker: UcpWorker, val conf: UcxShuffleConf, val id: Int)
- extends Closeable with Logging {
+/** Worker per thread wrapper, that maintains connection and progress logic.
+ */
+class UcxWorkerWrapper(
+ val worker: UcpWorker,
+ val conf: UcxShuffleConf,
+ val id: Int
+) extends Closeable
+ with Logging {
import UcxWorkerWrapper._
- private final val driverSocketAddress = new InetSocketAddress(conf.driverHost, conf.driverPort)
- private final val endpointParams = new UcpEndpointParams().setSocketAddress(driverSocketAddress)
+ private final val driverSocketAddress =
+ new InetSocketAddress(conf.driverHost, conf.driverPort)
+ private final val endpointParams = new UcpEndpointParams()
+ .setSocketAddress(driverSocketAddress)
.setPeerErrorHandlingMode()
val driverEndpoint: UcpEndpoint = worker.newEndpoint(endpointParams)
private final val connections = mutable.Map.empty[BlockManagerId, UcpEndpoint]
- private final val driverMetadata = mutable.Map.empty[ShuffleId, DriverMetadata]
+ private final val driverMetadata =
+ mutable.Map.empty[ShuffleId, DriverMetadata]
override def close(): Unit = {
- driverMetadata.values.foreach{
+ driverMetadata.values.foreach {
case DriverMetadata(address, rkey, length, data) => rkey.close()
}
driverMetadata.clear()
driverEndpoint.close()
- connections.foreach{
- case (_, endpoint) => endpoint.close()
+ connections.foreach { case (_, endpoint) =>
+ endpoint.close()
}
connections.clear()
worker.close()
driverMetadataBuffer.clear()
}
- /**
- * Blocking progress single request until it's not completed.
- */
+ /** Blocking progress of a single request until it is completed.
+ */
def waitRequest(request: UcpRequest): Unit = {
val startTime = System.currentTimeMillis()
worker.progressRequest(request)
- logDebug(s"Request completed in ${System.currentTimeMillis() - startTime} ms")
+ logDebug(
+ s"Request completed in ${System.currentTimeMillis() - startTime} ms"
+ )
}
- /**
- * Blocking progress while result queue is empty.
- */
+ /** Blocking progress while result queue is empty.
+ */
def fillQueueWithBlocks(queue: LinkedBlockingQueue[_]): Unit = {
while (queue.isEmpty) {
progress()
}
}
- /**
- * The only place for worker progress
- */
+ /** The only place for worker progress
+ */
private def progress(): Int = {
worker.progress()
}
- /**
- * Establish connections to known instances.
- */
+ /** Establish connections to known instances.
+ */
def preconnect(): Unit = {
UcxNode.getWorkerAddresses.keySet().asScala.foreach(getConnection)
}
@@ -136,61 +149,74 @@ class UcxWorkerWrapper(val worker: UcpWorker, val conf: UcxShuffleConf, val id:
while (workerAddresses.get(blockManagerId) == null) {
workerAddresses.wait(timeout)
if (System.currentTimeMillis() - startTime > timeout) {
- throw new UcxException(s"Didn't get worker address for $blockManagerId during $timeout")
+ throw new UcxException(
+ s"Didn't get worker address for $blockManagerId during $timeout"
+ )
}
}
}
}
- connections.getOrElseUpdate(blockManagerId, {
- logInfo(s"Worker $id connecting to $blockManagerId")
- val endpointParams = new UcpEndpointParams()
- .setPeerErrorHandlingMode()
- .setUcpAddress(workerAddresses.get(blockManagerId))
- worker.newEndpoint(endpointParams)
- })
+ connections.getOrElseUpdate(
+ blockManagerId, {
+ logInfo(s"Worker $id connecting to $blockManagerId")
+ val endpointParams = new UcpEndpointParams()
+ .setPeerErrorHandlingMode()
+ .setUcpAddress(workerAddresses.get(blockManagerId))
+ worker.newEndpoint(endpointParams)
+ }
+ )
}
- /**
- * Unpacks driver metadata RkeyBuffer for this worker.
- * Needed to perform PUT operation to publish map output info.
- */
+ /** Unpacks driver metadata RkeyBuffer for this worker.
+ * Needed to perform PUT operation to publish map output info.
+ */
def getDriverMetadata(shuffleId: ShuffleId): DriverMetadata = {
- driverMetadata.getOrElseUpdate(shuffleId, {
- val ucxShuffleHandle = SparkEnv.get.shuffleManager.asInstanceOf[CommonUcxShuffleManager]
- .shuffleIdToHandle(shuffleId)
- val (address, length, rkey): (Long, Int, ByteBuffer) = (ucxShuffleHandle.metadataBufferOnDriver.getAddress,
- ucxShuffleHandle.numMaps * conf.metadataBlockSize.toInt,
- ucxShuffleHandle.metadataBufferOnDriver.getRkeyBuffer)
-
- rkey.clear()
- val unpackedRkey = driverEndpoint.unpackRemoteKey(rkey)
- DriverMetadata(address, unpackedRkey, length, null)
- })
+ driverMetadata.getOrElseUpdate(
+ shuffleId, {
+ val ucxShuffleHandle = SparkEnv.get.shuffleManager
+ .asInstanceOf[CommonUcxShuffleManager]
+ .shuffleIdToHandle(shuffleId)
+ val (address, length, rkey): (Long, Int, ByteBuffer) = (
+ ucxShuffleHandle.metadataBufferOnDriver.getAddress,
+ ucxShuffleHandle.numMaps * conf.metadataBlockSize.toInt,
+ ucxShuffleHandle.metadataBufferOnDriver.getRkeyBuffer
+ )
+
+ rkey.clear()
+ val unpackedRkey = driverEndpoint.unpackRemoteKey(rkey)
+ DriverMetadata(address, unpackedRkey, length, null)
+ }
+ )
}
- /**
- * Fetches using ucp_get metadata buffer from driver, with all needed information
- * for offset and data addresses and keys.
- */
+ /** Fetches the metadata buffer from the driver using ucp_get, with all needed information
+ * for offset and data addresses and keys.
+ */
def fetchDriverMetadataBuffer(shuffleId: ShuffleId): DriverMetadata = {
- val handle = SparkEnv.get.shuffleManager.asInstanceOf[CommonUcxShuffleManager]
+ val handle = SparkEnv.get.shuffleManager
+ .asInstanceOf[CommonUcxShuffleManager]
.shuffleIdToHandle(shuffleId)
val metadata = getDriverMetadata(handle.shuffleId)
- UcxWorkerWrapper.driverMetadataBuffer.computeIfAbsent(shuffleId,
+ UcxWorkerWrapper.driverMetadataBuffer.computeIfAbsent(
+ shuffleId,
(t: ShuffleId) => {
val buffer = Platform.allocateDirectBuffer(metadata.length)
val request = driverEndpoint.getNonBlocking(
- metadata.address, metadata.driverRkey, buffer, null)
+ metadata.address,
+ metadata.driverRkey,
+ buffer,
+ null
+ )
waitRequest(request)
buffer
}
)
if (metadata.data == null) {
- metadata.data = UcxWorkerWrapper.driverMetadataBuffer.get(shuffleId)
+ metadata.data = UcxWorkerWrapper.driverMetadataBuffer.get(shuffleId)
}
metadata
}
@@ -203,5 +229,9 @@ object UcxWorkerWrapper {
val driverMetadataBuffer = new ConcurrentHashMap[ShuffleId, ByteBuffer]()
val metadataBlockSize: MapId =
- SparkEnv.get.shuffleManager.asInstanceOf[CommonUcxShuffleManager].ucxShuffleConf.metadataBlockSize.toInt
+ SparkEnv.get.shuffleManager
+ .asInstanceOf[CommonUcxShuffleManager]
+ .ucxShuffleConf
+ .metadataBlockSize
+ .toInt
}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleBlockResolver.scala
index 90084f39..d28db889 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleBlockResolver.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleBlockResolver.scala
@@ -1,33 +1,44 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle.compat.spark_2_1
import java.io.{File, RandomAccessFile}
import org.apache.spark.SparkEnv
-import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager, IndexShuffleBlockResolver}
+import org.apache.spark.shuffle.{
+ CommonUcxShuffleBlockResolver,
+ CommonUcxShuffleManager,
+ IndexShuffleBlockResolver
+}
import org.apache.spark.storage.ShuffleIndexBlockId
-/**
- * Mapper entry point for UcxShuffle plugin. Performs memory registration
- * of data and index files and publish addresses to driver metadata buffer.
- */
+/** Mapper entry point for UcxShuffle plugin. Performs memory registration
+ * of data and index files and publishes addresses to driver metadata buffer.
+ */
class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
- extends CommonUcxShuffleBlockResolver(ucxShuffleManager) {
+ extends CommonUcxShuffleBlockResolver(ucxShuffleManager) {
private def getIndexFile(shuffleId: Int, mapId: Int): File = {
- SparkEnv.get.blockManager
- .diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID))
+ SparkEnv.get.blockManager.diskBlockManager.getFile(
+ ShuffleIndexBlockId(
+ shuffleId,
+ mapId,
+ IndexShuffleBlockResolver.NOOP_REDUCE_ID
+ )
+ )
}
- /**
- * Mapper commit protocol extension. Register index and data files and publish all needed
- * metadata to driver.
- */
- override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Int,
- lengths: Array[Long], dataTmp: File): Unit = {
+ /** Mapper commit protocol extension. Registers index and data files and publishes all needed
+ * metadata to driver.
+ */
+ override def writeIndexFileAndCommit(
+ shuffleId: ShuffleId,
+ mapId: Int,
+ lengths: Array[Long],
+ dataTmp: File
+ ): Unit = {
super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
val dataFile = getDataFile(shuffleId, mapId)
val dataBackFile = new RandomAccessFile(dataFile, "rw")
@@ -39,6 +50,13 @@ class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
val indexFile = getIndexFile(shuffleId, mapId)
val indexBackFile = new RandomAccessFile(indexFile, "rw")
- writeIndexFileAndCommitCommon(shuffleId, mapId, lengths, dataTmp, indexBackFile, dataBackFile)
+ writeIndexFileAndCommitCommon(
+ shuffleId,
+ mapId,
+ lengths,
+ dataTmp,
+ indexBackFile,
+ dataBackFile
+ )
}
}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleManager.scala
index 7087e491..b209f763 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleManager.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleManager.scala
@@ -1,53 +1,78 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle
-import org.apache.spark.shuffle.compat.spark_2_1.{UcxShuffleBlockResolver, UcxShuffleReader}
+import org.apache.spark.shuffle.compat.spark_2_1.{
+ UcxShuffleBlockResolver,
+ UcxShuffleReader
+}
import org.apache.spark.util.ShutdownHookManager
import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
-/**
- * Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin
- * and injects needed logic in override methods.
- */
-class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) {
+/** Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin
+ * and injects needed logic in override methods.
+ */
+class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean)
+ extends CommonUcxShuffleManager(conf, isDriver) {
ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop)
- /**
- * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
- * Called on driver and guaranteed by spark that shuffle on executor will start after it.
- */
- override def registerShuffle[K, V, C](shuffleId: ShuffleId,
- numMaps: Int,
- dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ /** Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+ * Called on the driver; Spark guarantees that the shuffle on executors will start after it.
+ */
+ override def registerShuffle[K, V, C](
+ shuffleId: ShuffleId,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, C]
+ ): ShuffleHandle = {
assume(isDriver)
- val baseHandle = super.registerShuffle(shuffleId, numMaps, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]]
+ val baseHandle = super
+ .registerShuffle(shuffleId, numMaps, dependency)
+ .asInstanceOf[BaseShuffleHandle[K, V, C]]
registerShuffleCommon(baseHandle, shuffleId, numMaps)
}
- /**
- * Mapper callback on executor. Just start UcxNode and use Spark mapper logic.
- */
- override def getWriter[K, V](handle: ShuffleHandle, mapId: Int,
- context: TaskContext): ShuffleWriter[K, V] = {
+ /** Mapper callback on executor. Just start UcxNode and use Spark mapper logic.
+ */
+ override def getWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Int,
+ context: TaskContext
+ ): ShuffleWriter[K, V] = {
startUcxNodeIfMissing()
- shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,V,_]])
- super.getWriter(handle.asInstanceOf[UcxShuffleHandle[K,V,_]].baseHandle, mapId, context)
+ shuffleIdToHandle.putIfAbsent(
+ handle.shuffleId,
+ handle.asInstanceOf[UcxShuffleHandle[K, V, _]]
+ )
+ super.getWriter(
+ handle.asInstanceOf[UcxShuffleHandle[K, V, _]].baseHandle,
+ mapId,
+ context
+ )
}
- override val shuffleBlockResolver: UcxShuffleBlockResolver = new UcxShuffleBlockResolver(this)
+ override val shuffleBlockResolver: UcxShuffleBlockResolver =
+ new UcxShuffleBlockResolver(this)
- /**
- * Reducer callback on executor.
- */
- override def getReader[K, C](handle: ShuffleHandle, startPartition: Int,
- endPartition: Int, context: TaskContext): ShuffleReader[K, C] = {
+ /** Reducer callback on executor.
+ */
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext
+ ): ShuffleReader[K, C] = {
startUcxNodeIfMissing()
- shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,_,C]])
- new UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startPartition,
- endPartition, context)
+ shuffleIdToHandle.putIfAbsent(
+ handle.shuffleId,
+ handle.asInstanceOf[UcxShuffleHandle[K, _, C]]
+ )
+ new UcxShuffleReader(
+ handle.asInstanceOf[UcxShuffleHandle[K, _, C]],
+ startPartition,
+ endPartition,
+ context
+ )
}
}
-
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleReader.scala
index 1c7e1511..69b52886 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleReader.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleReader.scala
@@ -1,7 +1,7 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle.compat.spark_2_1
import java.io.InputStream
@@ -10,130 +10,164 @@ import java.util.concurrent.LinkedBlockingQueue
import org.apache.spark.internal.{Logging, config}
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle.ucx.reducer.compat.spark_2_1.UcxShuffleClient
-import org.apache.spark.shuffle.{ShuffleReader, UcxShuffleHandle, UcxShuffleManager}
-import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}
+import org.apache.spark.shuffle.{
+ ShuffleReader,
+ UcxShuffleHandle,
+ UcxShuffleManager
+}
+import org.apache.spark.storage.{
+ BlockId,
+ BlockManager,
+ ShuffleBlockFetcherIterator
+}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
-import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.{
+ InterruptibleIterator,
+ MapOutputTracker,
+ SparkEnv,
+ TaskContext
+}
-/**
- * Extension of Spark's shuffe reader with a logic of injection UcxShuffleClient,
- * and lazy progress only when result queue is empty.
- */
-class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C],
- startPartition: Int,
- endPartition: Int,
- context: TaskContext,
- serializerManager: SerializerManager = SparkEnv.get.serializerManager,
- blockManager: BlockManager = SparkEnv.get.blockManager,
- mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
- extends ShuffleReader[K, C] with Logging {
+/** Extension of Spark's shuffle reader with logic for injecting UcxShuffleClient,
+ * and lazy progress only when result queue is empty.
+ */
+class UcxShuffleReader[K, C](
+ handle: UcxShuffleHandle[K, _, C],
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ serializerManager: SerializerManager = SparkEnv.get.serializerManager,
+ blockManager: BlockManager = SparkEnv.get.blockManager,
+ mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker
+) extends ShuffleReader[K, C]
+ with Logging {
- private val dep = handle.baseHandle.dependency
+ private val dep = handle.baseHandle.dependency
- /** Read the combined key-values for this reduce task */
- override def read(): Iterator[Product2[K, C]] = {
- val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
- val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
- .ucxNode.getThreadLocalWorker
- val shuffleClient = new UcxShuffleClient(shuffleMetrics, workerWrapper)
- val wrappedStreams = new ShuffleBlockFetcherIterator(
- context,
- shuffleClient,
- blockManager,
- mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId,
- startPartition, endPartition),
- // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
- SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
- SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
+ /** Read the combined key-values for this reduce task */
+ override def read(): Iterator[Product2[K, C]] = {
+ val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
+ val workerWrapper = SparkEnv.get.shuffleManager
+ .asInstanceOf[UcxShuffleManager]
+ .ucxNode
+ .getThreadLocalWorker
+ val shuffleClient = new UcxShuffleClient(shuffleMetrics, workerWrapper)
+ val wrappedStreams = new ShuffleBlockFetcherIterator(
+ context,
+ shuffleClient,
+ blockManager,
+ mapOutputTracker.getMapSizesByExecutorId(
+ handle.shuffleId,
+ startPartition,
+ endPartition
+ ),
+ // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+ SparkEnv.get.conf
+ .getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
+ SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)
+ )
- // Ucx shuffle logic
- // Java reflection to get access to private results queue
- val queueField = wrappedStreams.getClass.getDeclaredField(
- "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
- queueField.setAccessible(true)
- val resultQueue = queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]]
+ // Ucx shuffle logic
+ // Java reflection to get access to private results queue
+ val queueField = wrappedStreams.getClass.getDeclaredField(
+ "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results"
+ )
+ queueField.setAccessible(true)
+ val resultQueue =
+ queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]]
- // Do progress if queue is empty before calling next on ShuffleIterator
- val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
- override def next(): (BlockId, InputStream) = {
- val startTime = System.currentTimeMillis()
- workerWrapper.fillQueueWithBlocks(resultQueue)
- shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime)
- wrappedStreams.next()
- }
+ // Do progress if queue is empty before calling next on ShuffleIterator
+ val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
+ override def next(): (BlockId, InputStream) = {
+ val startTime = System.currentTimeMillis()
+ workerWrapper.fillQueueWithBlocks(resultQueue)
+ shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime)
+ wrappedStreams.next()
+ }
- override def hasNext: Boolean = {
- val result = wrappedStreams.hasNext
- if (!result) {
- shuffleClient.close()
- }
- result
+ override def hasNext: Boolean = {
+ val result = wrappedStreams.hasNext
+ if (!result) {
+ shuffleClient.close()
}
+ result
}
- // End of ucx shuffle logic
+ }
+ // End of ucx shuffle logic
- val serializerInstance = dep.serializer.newInstance()
- val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
- // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
- // NextIterator. The NextIterator makes sure that close() is called on the
- // underlying InputStream when all records have been read.
- serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
- }
+ val serializerInstance = dep.serializer.newInstance()
+ val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
+ // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+ // NextIterator. The NextIterator makes sure that close() is called on the
+ // underlying InputStream when all records have been read.
+ serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+ }
- // Update the context task metrics for each record read.
- val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
- val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
- recordIter.map { record =>
- readMetrics.incRecordsRead(1)
- record
- },
- context.taskMetrics().mergeShuffleReadMetrics())
+ // Update the context task metrics for each record read.
+ val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
+ val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+ recordIter.map { record =>
+ readMetrics.incRecordsRead(1)
+ record
+ },
+ context.taskMetrics().mergeShuffleReadMetrics()
+ )
- // An interruptible iterator must be used here in order to support task cancellation
- val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
+ // An interruptible iterator must be used here in order to support task cancellation
+ val interruptibleIter =
+ new InterruptibleIterator[(Any, Any)](context, metricIter)
- val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
+ val aggregatedIter: Iterator[Product2[K, C]] =
+ if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// We are reading values that are already combined
- val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
- dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
+ val combinedKeyValuesIterator =
+ interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+ dep.aggregator.get
+ .combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
- val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+ val keyValuesIterator =
+ interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
- // Sort the output if there is a sort ordering defined.
- val resultIter = dep.keyOrdering match {
- case Some(keyOrd: Ordering[K]) =>
- // Create an ExternalSorter to sort the data.
- val sorter =
- new ExternalSorter[K, C, C](context,
- ordering = Some(keyOrd), serializer = dep.serializer)
- sorter.insertAll(aggregatedIter)
- context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
- context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
- context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
- // Use completion callback to stop sorter if task was finished/cancelled.
- CompletionIterator[Product2[K, C],
- Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
- case None =>
- aggregatedIter
- }
+ // Sort the output if there is a sort ordering defined.
+ val resultIter = dep.keyOrdering match {
+ case Some(keyOrd: Ordering[K]) =>
+ // Create an ExternalSorter to sort the data.
+ val sorter =
+ new ExternalSorter[K, C, C](
+ context,
+ ordering = Some(keyOrd),
+ serializer = dep.serializer
+ )
+ sorter.insertAll(aggregatedIter)
+ context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
+ // Use completion callback to stop sorter if task was finished/cancelled.
+ CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](
+ sorter.iterator,
+ sorter.stop()
+ )
+ case None =>
+ aggregatedIter
+ }
- resultIter match {
- case _: InterruptibleIterator[Product2[K, C]] => resultIter
- case _ =>
- // Use another interruptible iterator here to support task cancellation as aggregator
- // or(and) sorter may have consumed previous interruptible iterator.
- new InterruptibleIterator[Product2[K, C]](context, resultIter)
- }
+ resultIter match {
+ case _: InterruptibleIterator[Product2[K, C]] => resultIter
+ case _ =>
+ // Use another interruptible iterator here to support task cancellation, as the aggregator
+ // and/or the sorter may have consumed the previous interruptible iterator.
+ new InterruptibleIterator[Product2[K, C]](context, resultIter)
}
+ }
}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleBlockResolver.scala
index a47b0af6..6311b5e0 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleBlockResolver.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleBlockResolver.scala
@@ -1,33 +1,44 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle.compat.spark_2_4
import java.io.{File, RandomAccessFile}
import org.apache.spark.SparkEnv
-import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager, IndexShuffleBlockResolver}
+import org.apache.spark.shuffle.{
+ CommonUcxShuffleBlockResolver,
+ CommonUcxShuffleManager,
+ IndexShuffleBlockResolver
+}
import org.apache.spark.storage.ShuffleIndexBlockId
-/**
- * Mapper entry point for UcxShuffle plugin. Performs memory registration
- * of data and index files and publish addresses to driver metadata buffer.
- */
+/** Mapper entry point for UcxShuffle plugin. Performs memory registration
+ * of data and index files and publishes addresses to driver metadata buffer.
+ */
class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
- extends CommonUcxShuffleBlockResolver(ucxShuffleManager) {
+ extends CommonUcxShuffleBlockResolver(ucxShuffleManager) {
private def getIndexFile(shuffleId: Int, mapId: Int): File = {
- SparkEnv.get.blockManager
- .diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID))
+ SparkEnv.get.blockManager.diskBlockManager.getFile(
+ ShuffleIndexBlockId(
+ shuffleId,
+ mapId,
+ IndexShuffleBlockResolver.NOOP_REDUCE_ID
+ )
+ )
}
- /**
- * Mapper commit protocol extension. Register index and data files and publish all needed
- * metadata to driver.
- */
- override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Int,
- lengths: Array[Long], dataTmp: File): Unit = {
+ /** Mapper commit protocol extension. Registers index and data files and publishes all needed
+ * metadata to driver.
+ */
+ override def writeIndexFileAndCommit(
+ shuffleId: ShuffleId,
+ mapId: Int,
+ lengths: Array[Long],
+ dataTmp: File
+ ): Unit = {
super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
val dataFile = getDataFile(shuffleId, mapId)
val dataBackFile = new RandomAccessFile(dataFile, "rw")
@@ -39,6 +50,13 @@ class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
val indexFile = getIndexFile(shuffleId, mapId)
val indexBackFile = new RandomAccessFile(indexFile, "rw")
- writeIndexFileAndCommitCommon(shuffleId, mapId, lengths, dataTmp, indexBackFile, dataBackFile)
+ writeIndexFileAndCommitCommon(
+ shuffleId,
+ mapId,
+ lengths,
+ dataTmp,
+ indexBackFile,
+ dataBackFile
+ )
}
}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleManager.scala
index df72287d..841cc141 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleManager.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleManager.scala
@@ -1,53 +1,78 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle
-import org.apache.spark.shuffle.compat.spark_2_4.{UcxShuffleBlockResolver, UcxShuffleReader}
+import org.apache.spark.shuffle.compat.spark_2_4.{
+ UcxShuffleBlockResolver,
+ UcxShuffleReader
+}
import org.apache.spark.util.ShutdownHookManager
import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
-/**
- * Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin
- * and injects needed logic in override methods.
- */
-class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) {
+/** Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin
+ * and injects needed logic in override methods.
+ */
+class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean)
+ extends CommonUcxShuffleManager(conf, isDriver) {
ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop)
- /**
- * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
- * Called on driver and guaranteed by spark that shuffle on executor will start after it.
- */
- override def registerShuffle[K, V, C](shuffleId: ShuffleId,
- numMaps: Int,
- dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ /** Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+ * Called on the driver; Spark guarantees that the shuffle on executors starts after this call.
+ */
+ override def registerShuffle[K, V, C](
+ shuffleId: ShuffleId,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, C]
+ ): ShuffleHandle = {
assume(isDriver)
- val baseHandle = super.registerShuffle(shuffleId, numMaps, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]]
+ val baseHandle = super
+ .registerShuffle(shuffleId, numMaps, dependency)
+ .asInstanceOf[BaseShuffleHandle[K, V, C]]
registerShuffleCommon(baseHandle, shuffleId, numMaps)
}
- /**
- * Mapper callback on executor. Just start UcxNode and use Spark mapper logic.
- */
- override def getWriter[K, V](handle: ShuffleHandle, mapId: Int,
- context: TaskContext): ShuffleWriter[K, V] = {
+ /** Mapper callback on executor. Just start UcxNode and use Spark mapper logic.
+ */
+ override def getWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Int,
+ context: TaskContext
+ ): ShuffleWriter[K, V] = {
startUcxNodeIfMissing()
- shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,V,_]])
- super.getWriter(handle.asInstanceOf[UcxShuffleHandle[K,V,_]].baseHandle, mapId, context)
+ shuffleIdToHandle.putIfAbsent(
+ handle.shuffleId,
+ handle.asInstanceOf[UcxShuffleHandle[K, V, _]]
+ )
+ super.getWriter(
+ handle.asInstanceOf[UcxShuffleHandle[K, V, _]].baseHandle,
+ mapId,
+ context
+ )
}
- override val shuffleBlockResolver: UcxShuffleBlockResolver = new UcxShuffleBlockResolver(this)
+ override val shuffleBlockResolver: UcxShuffleBlockResolver =
+ new UcxShuffleBlockResolver(this)
- /**
- * Reducer callback on executor.
- */
- override def getReader[K, C](handle: ShuffleHandle, startPartition: Int,
- endPartition: Int, context: TaskContext): ShuffleReader[K, C] = {
+ /** Reducer callback on executor.
+ */
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext
+ ): ShuffleReader[K, C] = {
startUcxNodeIfMissing()
- shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,_,C]])
- new UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startPartition,
- endPartition, context)
+ shuffleIdToHandle.putIfAbsent(
+ handle.shuffleId,
+ handle.asInstanceOf[UcxShuffleHandle[K, _, C]]
+ )
+ new UcxShuffleReader(
+ handle.asInstanceOf[UcxShuffleHandle[K, _, C]],
+ startPartition,
+ endPartition,
+ context
+ )
}
}
-
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala
index a1ad2294..474c4988 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala
@@ -1,146 +1,180 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle.compat.spark_2_4
import java.io.InputStream
import java.util.concurrent.LinkedBlockingQueue
-import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.{
+ InterruptibleIterator,
+ MapOutputTracker,
+ SparkEnv,
+ TaskContext
+}
import org.apache.spark.internal.{Logging, config}
import org.apache.spark.serializer.SerializerManager
-import org.apache.spark.shuffle.{ShuffleReader, UcxShuffleHandle, UcxShuffleManager}
+import org.apache.spark.shuffle.{
+ ShuffleReader,
+ UcxShuffleHandle,
+ UcxShuffleManager
+}
import org.apache.spark.shuffle.ucx.reducer.compat.spark_2_4.UcxShuffleClient
-import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}
+import org.apache.spark.storage.{
+ BlockId,
+ BlockManager,
+ ShuffleBlockFetcherIterator
+}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
-/**
- * Extension of Spark's shuffe reader with a logic of injection UcxShuffleClient,
- * and lazy progress only when result queue is empty.
- */
-class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C],
- startPartition: Int,
- endPartition: Int,
- context: TaskContext,
- serializerManager: SerializerManager = SparkEnv.get.serializerManager,
- blockManager: BlockManager = SparkEnv.get.blockManager,
- mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
- extends ShuffleReader[K, C] with Logging {
+/** Extension of Spark's shuffle reader with logic for injecting UcxShuffleClient,
+ * and lazy progress only when result queue is empty.
+ */
+class UcxShuffleReader[K, C](
+ handle: UcxShuffleHandle[K, _, C],
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ serializerManager: SerializerManager = SparkEnv.get.serializerManager,
+ blockManager: BlockManager = SparkEnv.get.blockManager,
+ mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker
+) extends ShuffleReader[K, C]
+ with Logging {
- private val dep = handle.baseHandle.dependency
+ private val dep = handle.baseHandle.dependency
- /** Read the combined key-values for this reduce task */
- override def read(): Iterator[Product2[K, C]] = {
- val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
- val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
- .ucxNode.getThreadLocalWorker
- val shuffleClient = new UcxShuffleClient(shuffleMetrics, workerWrapper)
- val wrappedStreams = new ShuffleBlockFetcherIterator(
- context,
- shuffleClient,
- blockManager,
- mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId,
- startPartition, endPartition),
- serializerManager.wrapStream,
- // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
- SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
- SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
- SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
- SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
- SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
+ /** Read the combined key-values for this reduce task */
+ override def read(): Iterator[Product2[K, C]] = {
+ val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
+ val workerWrapper = SparkEnv.get.shuffleManager
+ .asInstanceOf[UcxShuffleManager]
+ .ucxNode
+ .getThreadLocalWorker
+ val shuffleClient = new UcxShuffleClient(shuffleMetrics, workerWrapper)
+ val wrappedStreams = new ShuffleBlockFetcherIterator(
+ context,
+ shuffleClient,
+ blockManager,
+ mapOutputTracker.getMapSizesByExecutorId(
+ handle.shuffleId,
+ startPartition,
+ endPartition
+ ),
+ serializerManager.wrapStream,
+ // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+ SparkEnv.get.conf
+ .getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
+ SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
+ SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
+ SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
+ SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)
+ )
- // Ucx shuffle logic
- // Java reflection to get access to private results queue
- val queueField = wrappedStreams.getClass.getDeclaredField(
- "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
- queueField.setAccessible(true)
- val resultQueue = queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]]
+ // Ucx shuffle logic
+ // Java reflection to get access to private results queue
+ val queueField = wrappedStreams.getClass.getDeclaredField(
+ "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results"
+ )
+ queueField.setAccessible(true)
+ val resultQueue =
+ queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]]
- // Do progress if queue is empty before calling next on ShuffleIterator
- val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
- override def next(): (BlockId, InputStream) = {
- val startTime = System.currentTimeMillis()
- workerWrapper.fillQueueWithBlocks(resultQueue)
- shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime)
- wrappedStreams.next()
- }
+ // Do progress if queue is empty before calling next on ShuffleIterator
+ val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
+ override def next(): (BlockId, InputStream) = {
+ val startTime = System.currentTimeMillis()
+ workerWrapper.fillQueueWithBlocks(resultQueue)
+ shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime)
+ wrappedStreams.next()
+ }
- override def hasNext: Boolean = {
- val result = wrappedStreams.hasNext
- if (!result) {
- shuffleClient.close()
- }
- result
+ override def hasNext: Boolean = {
+ val result = wrappedStreams.hasNext
+ if (!result) {
+ shuffleClient.close()
}
+ result
}
- // End of ucx shuffle logic
+ }
+ // End of ucx shuffle logic
- val serializerInstance = dep.serializer.newInstance()
- val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
- // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
- // NextIterator. The NextIterator makes sure that close() is called on the
- // underlying InputStream when all records have been read.
- serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
- }
+ val serializerInstance = dep.serializer.newInstance()
+ val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
+ // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+ // NextIterator. The NextIterator makes sure that close() is called on the
+ // underlying InputStream when all records have been read.
+ serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+ }
- // Update the context task metrics for each record read.
- val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
- val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
- recordIter.map { record =>
- readMetrics.incRecordsRead(1)
- record
- },
- context.taskMetrics().mergeShuffleReadMetrics())
+ // Update the context task metrics for each record read.
+ val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
+ val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+ recordIter.map { record =>
+ readMetrics.incRecordsRead(1)
+ record
+ },
+ context.taskMetrics().mergeShuffleReadMetrics()
+ )
- // An interruptible iterator must be used here in order to support task cancellation
- val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
+ // An interruptible iterator must be used here in order to support task cancellation
+ val interruptibleIter =
+ new InterruptibleIterator[(Any, Any)](context, metricIter)
- val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
+ val aggregatedIter: Iterator[Product2[K, C]] =
+ if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// We are reading values that are already combined
- val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
- dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
+ val combinedKeyValuesIterator =
+ interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+ dep.aggregator.get
+ .combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
- val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+ val keyValuesIterator =
+ interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
- // Sort the output if there is a sort ordering defined.
- val resultIter = dep.keyOrdering match {
- case Some(keyOrd: Ordering[K]) =>
- // Create an ExternalSorter to sort the data.
- val sorter =
- new ExternalSorter[K, C, C](context,
- ordering = Some(keyOrd), serializer = dep.serializer)
- sorter.insertAll(aggregatedIter)
- context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
- context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
- context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
- // Use completion callback to stop sorter if task was finished/cancelled.
- context.addTaskCompletionListener[Unit](_ => {
- sorter.stop()
- })
- CompletionIterator[Product2[K, C],
- Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
- case None =>
- aggregatedIter
- }
+ // Sort the output if there is a sort ordering defined.
+ val resultIter = dep.keyOrdering match {
+ case Some(keyOrd: Ordering[K]) =>
+ // Create an ExternalSorter to sort the data.
+ val sorter =
+ new ExternalSorter[K, C, C](
+ context,
+ ordering = Some(keyOrd),
+ serializer = dep.serializer
+ )
+ sorter.insertAll(aggregatedIter)
+ context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
+ // Use completion callback to stop sorter if task was finished/cancelled.
+ context.addTaskCompletionListener[Unit](_ => {
+ sorter.stop()
+ })
+ CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](
+ sorter.iterator,
+ sorter.stop()
+ )
+ case None =>
+ aggregatedIter
+ }
- resultIter match {
- case _: InterruptibleIterator[Product2[K, C]] => resultIter
- case _ =>
- // Use another interruptible iterator here to support task cancellation as aggregator
- // or(and) sorter may have consumed previous interruptible iterator.
- new InterruptibleIterator[Product2[K, C]](context, resultIter)
- }
+ resultIter match {
+ case _: InterruptibleIterator[Product2[K, C]] => resultIter
+ case _ =>
+ // Use another interruptible iterator here to support task cancellation, as the aggregator
+ // and/or the sorter may have consumed the previous interruptible iterator.
+ new InterruptibleIterator[Product2[K, C]](context, resultIter)
}
+ }
}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleDataIO.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleDataIO.scala
index 6ecb74cd..2d961046 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleDataIO.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleDataIO.scala
@@ -1,7 +1,7 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle.compat.spark_3_0
import org.apache.spark.SparkConf
@@ -9,10 +9,11 @@ import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO
-/**
- * Ucx local disk IO plugin to handle logic of writing to local disk and shuffle memory registration.
- */
-case class UcxLocalDiskShuffleDataIO(sparkConf: SparkConf) extends LocalDiskShuffleDataIO(sparkConf) with Logging {
+/** Ucx local disk IO plugin to handle logic of writing to local disk and shuffle memory registration.
+ */
+case class UcxLocalDiskShuffleDataIO(sparkConf: SparkConf)
+ extends LocalDiskShuffleDataIO(sparkConf)
+ with Logging {
override def executor(): ShuffleExecutorComponents = {
new UcxLocalDiskShuffleExecutorComponents(sparkConf)
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala
index b8e5e72e..d6091e43 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala
@@ -1,7 +1,7 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle.compat.spark_3_0
import java.util
@@ -9,39 +9,67 @@ import java.util.Optional
import org.apache.spark.internal.Logging
import org.apache.spark.{SparkConf, SparkEnv}
-import org.apache.spark.shuffle.sort.io.{LocalDiskShuffleExecutorComponents, LocalDiskShuffleMapOutputWriter, LocalDiskSingleSpillMapOutputWriter}
+import org.apache.spark.shuffle.sort.io.{
+ LocalDiskShuffleExecutorComponents,
+ LocalDiskShuffleMapOutputWriter,
+ LocalDiskSingleSpillMapOutputWriter
+}
import org.apache.spark.shuffle.UcxShuffleManager
-import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter}
+import org.apache.spark.shuffle.api.{
+ ShuffleMapOutputWriter,
+ SingleSpillShuffleMapOutputWriter
+}
-/**
- * Entry point to UCX executor.
- */
+/** Entry point to UCX executor.
+ */
class UcxLocalDiskShuffleExecutorComponents(sparkConf: SparkConf)
- extends LocalDiskShuffleExecutorComponents(sparkConf) with Logging{
+ extends LocalDiskShuffleExecutorComponents(sparkConf)
+ with Logging {
private var blockResolver: UcxShuffleBlockResolver = _
- override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = {
- val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
+ override def initializeExecutor(
+ appId: String,
+ execId: String,
+ extraConfigs: util.Map[String, String]
+ ): Unit = {
+ val ucxShuffleManager =
+ SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
ucxShuffleManager.startUcxNodeIfMissing()
blockResolver = ucxShuffleManager.shuffleBlockResolver
}
- override def createMapOutputWriter(shuffleId: Int, mapTaskId: Long, numPartitions: Int): ShuffleMapOutputWriter = {
+ override def createMapOutputWriter(
+ shuffleId: Int,
+ mapTaskId: Long,
+ numPartitions: Int
+ ): ShuffleMapOutputWriter = {
if (blockResolver == null) {
throw new IllegalStateException(
- "Executor components must be initialized before getting writers.")
+ "Executor components must be initialized before getting writers."
+ )
}
new LocalDiskShuffleMapOutputWriter(
- shuffleId, mapTaskId, numPartitions, blockResolver, sparkConf)
+ shuffleId,
+ mapTaskId,
+ numPartitions,
+ blockResolver,
+ sparkConf
+ )
}
- override def createSingleFileMapOutputWriter(shuffleId: Int, mapId: Long): Optional[SingleSpillShuffleMapOutputWriter] = {
+ override def createSingleFileMapOutputWriter(
+ shuffleId: Int,
+ mapId: Long
+ ): Optional[SingleSpillShuffleMapOutputWriter] = {
if (blockResolver == null) {
throw new IllegalStateException(
- "Executor components must be initialized before getting writers.")
+ "Executor components must be initialized before getting writers."
+ )
}
- Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver))
+ Optional.of(
+ new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver)
+ )
}
}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleBlockResolver.scala
index 8e172f96..d2a762da 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleBlockResolver.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleBlockResolver.scala
@@ -1,7 +1,7 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle.compat.spark_3_0
import java.io.{File, RandomAccessFile}
@@ -9,29 +9,39 @@ import java.io.{File, RandomAccessFile}
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.network.shuffle.ExecutorDiskUtils
import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
-import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager}
+import org.apache.spark.shuffle.{
+ CommonUcxShuffleBlockResolver,
+ CommonUcxShuffleManager
+}
import org.apache.spark.storage.ShuffleIndexBlockId
-/**
- * Mapper entry point for UcxShuffle plugin. Performs memory registration
- * of data and index files and publish addresses to driver metadata buffer.
- */
+/** Mapper entry point for UcxShuffle plugin. Performs memory registration
+ * of data and index files and publishes addresses to driver metadata buffer.
+ */
class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
- extends CommonUcxShuffleBlockResolver(ucxShuffleManager) {
+ extends CommonUcxShuffleBlockResolver(ucxShuffleManager) {
private def getIndexFile(
- shuffleId: Int,
- mapId: Long,
- dirs: Option[Array[String]] = None): File = {
+ shuffleId: Int,
+ mapId: Long,
+ dirs: Option[Array[String]] = None
+ ): File = {
val blockId = ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)
val blockManager = SparkEnv.get.blockManager
dirs
- .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name))
+ .map(
+ ExecutorDiskUtils
+ .getFile(_, blockManager.subDirsPerLocalDir, blockId.name)
+ )
.getOrElse(blockManager.diskBlockManager.getFile(blockId))
}
- override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long,
- lengths: Array[Long], dataTmp: File): Unit = {
+ override def writeIndexFileAndCommit(
+ shuffleId: ShuffleId,
+ mapId: Long,
+ lengths: Array[Long],
+ dataTmp: File
+ ): Unit = {
super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
// In Spark-3.0 MapId is long and unique among all jobs in spark. We need to use partitionId as offset
// in metadata buffer
@@ -47,6 +57,13 @@ class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
val indexFile = getIndexFile(shuffleId, mapId)
val indexBackFile = new RandomAccessFile(indexFile, "rw")
- writeIndexFileAndCommitCommon(shuffleId, partitionId, lengths, dataTmp, indexBackFile, dataBackFile)
+ writeIndexFileAndCommitCommon(
+ shuffleId,
+ partitionId,
+ lengths,
+ dataTmp,
+ indexBackFile,
+ dataBackFile
+ )
}
}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleManager.scala
index 4d4f8c09..2e504c7a 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleManager.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleManager.scala
@@ -1,40 +1,64 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle
import scala.collection.JavaConverters._
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
-import org.apache.spark.shuffle.compat.spark_3_0.{UcxShuffleBlockResolver, UcxShuffleReader}
-import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, SortShuffleWriter, UnsafeShuffleWriter}
+import org.apache.spark.shuffle.compat.spark_3_0.{
+ UcxShuffleBlockResolver,
+ UcxShuffleReader
+}
+import org.apache.spark.shuffle.sort.{
+ SerializedShuffleHandle,
+ SortShuffleWriter,
+ UnsafeShuffleWriter
+}
import org.apache.spark.util.ShutdownHookManager
import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
-/**
- * Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin
- * and injects needed logic in override methods.
- */
-class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) {
+/** Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin
+ * and injects needed logic in override methods.
+ */
+class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean)
+ extends CommonUcxShuffleManager(conf, isDriver) {
ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop)
- private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
+ private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(
+ conf
+ )
override val shuffleBlockResolver = new UcxShuffleBlockResolver(this)
- override def registerShuffle[K, V, C](shuffleId: ShuffleId, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ override def registerShuffle[K, V, C](
+ shuffleId: ShuffleId,
+ dependency: ShuffleDependency[K, V, C]
+ ): ShuffleHandle = {
assume(isDriver)
- val numMaps = dependency.partitioner.numPartitions
- val baseHandle = super.registerShuffle(shuffleId, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]]
+ val numMaps = dependency.rdd.getNumPartitions
+ val baseHandle = super
+ .registerShuffle(shuffleId, dependency)
+ .asInstanceOf[BaseShuffleHandle[K, V, C]]
registerShuffleCommon(baseHandle, shuffleId, numMaps)
}
- override def getWriter[K, V](handle: ShuffleHandle, mapId: Long, context: TaskContext,
- metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
- shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K, V, _]])
+ override def getWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Long,
+ context: TaskContext,
+ metrics: ShuffleWriteMetricsReporter
+ ): ShuffleWriter[K, V] = {
+ shuffleIdToHandle.putIfAbsent(
+ handle.shuffleId,
+ handle.asInstanceOf[UcxShuffleHandle[K, V, _]]
+ )
val env = SparkEnv.get
handle.asInstanceOf[UcxShuffleHandle[K, V, _]].baseHandle match {
- case unsafeShuffleHandle: SerializedShuffleHandle[K@unchecked, V@unchecked] =>
+ case unsafeShuffleHandle: SerializedShuffleHandle[
+ K @unchecked,
+ V @unchecked
+ ] =>
new UnsafeShuffleWriter(
env.blockManager,
context.taskMemoryManager(),
@@ -43,31 +67,54 @@ class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends
context,
env.conf,
metrics,
- shuffleExecutorComponents)
- case other: BaseShuffleHandle[K@unchecked, V@unchecked, _] =>
+ shuffleExecutorComponents
+ )
+ case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(
- shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents)
+ shuffleBlockResolver,
+ other,
+ mapId,
+ context,
+ shuffleExecutorComponents
+ )
}
}
- override def getReader[K, C](handle: ShuffleHandle, startPartition: MapId, endPartition: MapId,
- context: TaskContext, metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: MapId,
+ endPartition: MapId,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter
+ ): ShuffleReader[K, C] = {
startUcxNodeIfMissing()
- shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K, _, C]])
- new UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startPartition, endPartition,
- context, readMetrics = metrics, shouldBatchFetch = true)
+ shuffleIdToHandle.putIfAbsent(
+ handle.shuffleId,
+ handle.asInstanceOf[UcxShuffleHandle[K, _, C]]
+ )
+ new UcxShuffleReader(
+ handle.asInstanceOf[UcxShuffleHandle[K, _, C]],
+ startPartition,
+ endPartition,
+ context,
+ readMetrics = metrics,
+ shouldBatchFetch = true
+ )
}
-
- private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
- val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
- val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX)
- .toMap
+ private def loadShuffleExecutorComponents(
+ conf: SparkConf
+ ): ShuffleExecutorComponents = {
+ val executorComponents =
+ ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
+ val extraConfigs =
+ conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap
executorComponents.initializeExecutor(
conf.getAppId,
SparkEnv.get.executorId,
- extraConfigs.asJava)
+ extraConfigs.asJava
+ )
executorComponents
}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala
index 7dc34a7b..5eed9d37 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala
@@ -1,7 +1,7 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
package org.apache.spark.shuffle.compat.spark_3_0
import java.io.InputStream
@@ -13,161 +13,215 @@ import org.apache.spark.internal.{Logging, config}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle.ucx.reducer.compat.spark_3_0.UcxShuffleClient
-import org.apache.spark.shuffle.{ShuffleReadMetricsReporter, ShuffleReader, UcxShuffleHandle, UcxShuffleManager}
-import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockBatchId, ShuffleBlockFetcherIterator, ShuffleBlockId}
+import org.apache.spark.shuffle.{
+ ShuffleReadMetricsReporter,
+ ShuffleReader,
+ UcxShuffleHandle,
+ UcxShuffleManager
+}
+import org.apache.spark.storage.{
+ BlockId,
+ BlockManager,
+ ShuffleBlockBatchId,
+ ShuffleBlockFetcherIterator,
+ ShuffleBlockId
+}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
-import org.apache.spark.{InterruptibleIterator, SparkEnv, SparkException, TaskContext}
-
+import org.apache.spark.{
+ InterruptibleIterator,
+ SparkEnv,
+ SparkException,
+ TaskContext
+}
-/**
- * Extension of Spark's shuffe reader with a logic of injection UcxShuffleClient,
- * and lazy progress only when result queue is empty.
- */
-class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C],
- startPartition: Int,
- endPartition: Int,
- context: TaskContext,
- serializerManager: SerializerManager = SparkEnv.get.serializerManager,
- blockManager: BlockManager = SparkEnv.get.blockManager,
- readMetrics: ShuffleReadMetricsReporter,
- shouldBatchFetch: Boolean = false) extends ShuffleReader[K, C] with Logging {
-
- private val dep = handle.baseHandle.dependency
-
- /** Read the combined key-values for this reduce task */
- override def read(): Iterator[Product2[K, C]] = {
- val (blocksByAddressIterator1, blocksByAddressIterator2) = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
- handle.shuffleId, startPartition, endPartition).duplicate
- val mapIdToBlockIndex = blocksByAddressIterator2.flatMap{
- case (_, blocks) => blocks.map {
- case (blockId, _, mapIdx) => blockId match {
- case x: ShuffleBlockId => (x.mapId.asInstanceOf[java.lang.Long], mapIdx.asInstanceOf[java.lang.Integer])
- case x: ShuffleBlockBatchId => (x.mapId.asInstanceOf[java.lang.Long], mapIdx.asInstanceOf[java.lang.Integer])
+/** Extension of Spark's shuffle reader that injects a UcxShuffleClient
+ * and makes progress lazily, only when the result queue is empty.
+ */
+class UcxShuffleReader[K, C](
+ handle: UcxShuffleHandle[K, _, C],
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ serializerManager: SerializerManager = SparkEnv.get.serializerManager,
+ blockManager: BlockManager = SparkEnv.get.blockManager,
+ readMetrics: ShuffleReadMetricsReporter,
+ shouldBatchFetch: Boolean = false
+) extends ShuffleReader[K, C]
+ with Logging {
+
+ private val dep = handle.baseHandle.dependency
+
+ /** Read the combined key-values for this reduce task */
+ override def read(): Iterator[Product2[K, C]] = {
+ val (blocksByAddressIterator1, blocksByAddressIterator2) =
+ SparkEnv.get.mapOutputTracker
+ .getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
+ .duplicate
+ val mapIdToBlockIndex = blocksByAddressIterator2.flatMap {
+ case (_, blocks) =>
+ blocks.map { case (blockId, _, mapIdx) =>
+ blockId match {
+ case x: ShuffleBlockId =>
+ (
+ x.mapId.asInstanceOf[java.lang.Long],
+ mapIdx.asInstanceOf[java.lang.Integer]
+ )
+ case x: ShuffleBlockBatchId =>
+ (
+ x.mapId.asInstanceOf[java.lang.Long],
+ mapIdx.asInstanceOf[java.lang.Integer]
+ )
case _ => throw new SparkException("Unknown block")
}
}
- }.toMap
-
- val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
- .ucxNode.getThreadLocalWorker
- val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
- val shuffleClient = new UcxShuffleClient(handle.shuffleId, workerWrapper, mapIdToBlockIndex.asJava, shuffleMetrics)
- val shuffleIterator = new ShuffleBlockFetcherIterator(
- context,
- shuffleClient,
- blockManager,
- blocksByAddressIterator1,
- serializerManager.wrapStream,
- // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
- SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
- SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
- SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
- SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
- SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
- SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
- readMetrics,
- fetchContinuousBlocksInBatch)
-
- val wrappedStreams = shuffleIterator.toCompletionIterator
-
- // Ucx shuffle logic
- // Java reflection to get access to private results queue
- val queueField = shuffleIterator.getClass.getDeclaredField(
- "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
- queueField.setAccessible(true)
- val resultQueue = queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]]
-
- // Do progress if queue is empty before calling next on ShuffleIterator
- val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
- override def next(): (BlockId, InputStream) = {
- val startTime = System.currentTimeMillis()
- workerWrapper.fillQueueWithBlocks(resultQueue)
- readMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime)
- wrappedStreams.next()
- }
+ }.toMap
+
+ val workerWrapper = SparkEnv.get.shuffleManager
+ .asInstanceOf[UcxShuffleManager]
+ .ucxNode
+ .getThreadLocalWorker
+ val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
+ val shuffleClient = new UcxShuffleClient(
+ handle.shuffleId,
+ workerWrapper,
+ mapIdToBlockIndex.asJava,
+ shuffleMetrics
+ )
+ val shuffleIterator = new ShuffleBlockFetcherIterator(
+ context,
+ shuffleClient,
+ blockManager,
+ blocksByAddressIterator1,
+ serializerManager.wrapStream,
+ // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+ SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
+ SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
+ SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
+ SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
+ SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
+ SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
+ readMetrics,
+ fetchContinuousBlocksInBatch
+ )
+
+ val wrappedStreams = shuffleIterator.toCompletionIterator
+
+ // Ucx shuffle logic
+ // Java reflection to get access to private results queue
+ val queueField = shuffleIterator.getClass.getDeclaredField(
+ "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results"
+ )
+ queueField.setAccessible(true)
+ val resultQueue =
+ queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]]
+
+ // Do progress if queue is empty before calling next on ShuffleIterator
+ val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
+ override def next(): (BlockId, InputStream) = {
+ val startTime = System.currentTimeMillis()
+ workerWrapper.fillQueueWithBlocks(resultQueue)
+ readMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime)
+ wrappedStreams.next()
+ }
- override def hasNext: Boolean = {
- val result = wrappedStreams.hasNext
- if (!result) {
- shuffleClient.close()
- }
- result
+ override def hasNext: Boolean = {
+ val result = wrappedStreams.hasNext
+ if (!result) {
+ shuffleClient.close()
}
+ result
}
- // End of ucx shuffle logic
-
- val serializerInstance = dep.serializer.newInstance()
-
- // Create a key/value iterator for each stream
- val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
- // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
- // NextIterator. The NextIterator makes sure that close() is called on the
- // underlying InputStream when all records have been read.
- serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
- }
+ }
+ // End of ucx shuffle logic
- // Update the context task metrics for each record read.
- val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
- recordIter.map { record =>
- readMetrics.incRecordsRead(1)
- record
- },
- context.taskMetrics().mergeShuffleReadMetrics())
+ val serializerInstance = dep.serializer.newInstance()
- // An interruptible iterator must be used here in order to support task cancellation
- val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
+ // Create a key/value iterator for each stream
+ val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
+ // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+ // NextIterator. The NextIterator makes sure that close() is called on the
+ // underlying InputStream when all records have been read.
+ serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+ }
- val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
+ // Update the context task metrics for each record read.
+ val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+ recordIter.map { record =>
+ readMetrics.incRecordsRead(1)
+ record
+ },
+ context.taskMetrics().mergeShuffleReadMetrics()
+ )
+
+ // An interruptible iterator must be used here in order to support task cancellation
+ val interruptibleIter =
+ new InterruptibleIterator[(Any, Any)](context, metricIter)
+
+ val aggregatedIter: Iterator[Product2[K, C]] =
+ if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// We are reading values that are already combined
- val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
- dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
+ val combinedKeyValuesIterator =
+ interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+ dep.aggregator.get
+ .combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
- val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+ val keyValuesIterator =
+ interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
- // Sort the output if there is a sort ordering defined.
- val resultIter = dep.keyOrdering match {
- case Some(keyOrd: Ordering[K]) =>
- // Create an ExternalSorter to sort the data.
- val sorter =
- new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
- sorter.insertAll(aggregatedIter)
- context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
- context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
- context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
- // Use completion callback to stop sorter if task was finished/cancelled.
- context.addTaskCompletionListener[Unit](_ => {
- sorter.stop()
- })
- CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
- case None =>
- aggregatedIter
- }
+ // Sort the output if there is a sort ordering defined.
+ val resultIter = dep.keyOrdering match {
+ case Some(keyOrd: Ordering[K]) =>
+ // Create an ExternalSorter to sort the data.
+ val sorter =
+ new ExternalSorter[K, C, C](
+ context,
+ ordering = Some(keyOrd),
+ serializer = dep.serializer
+ )
+ sorter.insertAll(aggregatedIter)
+ context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
+ // Use completion callback to stop sorter if task was finished/cancelled.
+ context.addTaskCompletionListener[Unit](_ => {
+ sorter.stop()
+ })
+ CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](
+ sorter.iterator,
+ sorter.stop()
+ )
+ case None =>
+ aggregatedIter
+ }
- resultIter match {
- case _: InterruptibleIterator[Product2[K, C]] => resultIter
- case _ =>
- // Use another interruptible iterator here to support task cancellation as aggregator
- // or(and) sorter may have consumed previous interruptible iterator.
- new InterruptibleIterator[Product2[K, C]](context, resultIter)
- }
+ resultIter match {
+ case _: InterruptibleIterator[Product2[K, C]] => resultIter
+ case _ =>
+ // Use another interruptible iterator here to support task cancellation, as the aggregator
+ // and/or the sorter may have consumed the previous interruptible iterator.
+ new InterruptibleIterator[Product2[K, C]](context, resultIter)
}
+ }
private def fetchContinuousBlocksInBatch: Boolean = {
val conf = SparkEnv.get.conf
- val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects
+ val serializerRelocatable =
+ dep.serializer.supportsRelocationOfSerializedObjects
val compressed = conf.get(config.SHUFFLE_COMPRESS)
val codecConcatenation = if (compressed) {
- CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf))
+ CompressionCodec.supportsConcatenationOfSerializedStreams(
+ CompressionCodec.createCodec(conf)
+ )
} else {
true
}
@@ -176,12 +230,14 @@ class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C],
val doBatchFetch = shouldBatchFetch && serializerRelocatable &&
(!compressed || codecConcatenation) && !useOldFetchProtocol
if (shouldBatchFetch && !doBatchFetch) {
- logWarning("The feature tag of continuous shuffle block fetching is set to true, but " +
- "we can not enable the feature because other conditions are not satisfied. " +
- s"Shuffle compress: $compressed, serializer ${dep.serializer.getClass.getName} " +
- s"relocatable: $serializerRelocatable, " +
- s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " +
- s"$useOldFetchProtocol.")
+ logWarning(
+ "The feature tag of continuous shuffle block fetching is set to true, but " +
+ "we can not enable the feature because other conditions are not satisfied. " +
+ s"Shuffle compress: $compressed, serializer ${dep.serializer.getClass.getName} " +
+ s"relocatable: $serializerRelocatable, " +
+ s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " +
+ s"$useOldFetchProtocol."
+ )
}
doBatchFetch
}