diff --git a/.scalafmt.conf b/.scalafmt.conf new file mode 100644 index 00000000..834f2d20 --- /dev/null +++ b/.scalafmt.conf @@ -0,0 +1 @@ +version = 2.7.5 \ No newline at end of file diff --git a/README.md b/README.md index b9cb6da3..b23b6591 100755 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ Build instructions: ``` % git clone https://github.com/openucx/sparkucx % cd sparkucx -% mvn -DskipTests clean package -Pspark-2.4 +% mvn -DskipTests clean package -Pspark-3.0 ``` ### Performance diff --git a/pom.xml b/pom.xml index b23b54ac..1e2826ad 100755 --- a/pom.xml +++ b/pom.xml @@ -149,7 +149,7 @@ See file LICENSE for terms. org.openucx jucx - 1.11.0-rc3 + 1.14.1 diff --git a/src/main/java/org/apache/spark/shuffle/ucx/UcxNode.java b/src/main/java/org/apache/spark/shuffle/ucx/UcxNode.java index 625281b9..5f0f73f6 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/UcxNode.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/UcxNode.java @@ -33,200 +33,201 @@ * Single instance class per spark process, that keeps UcpContext, memory and worker pools. */ public class UcxNode implements Closeable { - // Global - private static final Logger logger = LoggerFactory.getLogger(UcxNode.class); - private final boolean isDriver; - private final UcpContext context; - private final MemoryPool memoryPool; - private final UcpWorkerParams workerParams = new UcpWorkerParams(); - private final UcpWorker globalWorker; - private final UcxShuffleConf conf; - // Mapping from spark's entity of BlockManagerId to UcxEntity workerAddress. - private static final ConcurrentHashMap workerAdresses = - new ConcurrentHashMap<>(); - private final Thread listenerProgressThread; - private boolean closed = false; - - // Driver - private UcpListener listener; - // Mapping from UcpEndpoint to ByteBuffer of RPC message, to introduce executor to cluster - private static final ConcurrentHashMap rpcConnections = - new ConcurrentHashMap<>(); - private List backwardEndpoints = new ArrayList<>(); - - // Executor - private UcpEndpoint globalDriverEndpoint; - // Keep track of allocated workers to correctly close them. - private static final Set allocatedWorkers = ConcurrentHashMap.newKeySet(); - private final ThreadLocal threadLocalWorker; - - public UcxNode(UcxShuffleConf conf, boolean isDriver) { - this.conf = conf; - this.isDriver = isDriver; - UcpParams params = new UcpParams().requestTagFeature() - .requestRmaFeature().requestWakeupFeature() - .setMtWorkersShared(true); - context = new UcpContext(params); - memoryPool = new MemoryPool(context, conf); - globalWorker = context.newWorker(workerParams); - InetSocketAddress driverAddress = new InetSocketAddress(conf.driverHost(), conf.driverPort()); - - if (isDriver) { - startDriver(driverAddress); - } else { - startExecutor(driverAddress); + // Global + private static final Logger logger = LoggerFactory.getLogger(UcxNode.class); + private final boolean isDriver; + private final UcpContext context; + private final MemoryPool memoryPool; + private final UcpWorkerParams workerParams = new UcpWorkerParams(); + private final UcpWorker globalWorker; + private final UcxShuffleConf conf; + // Mapping from spark's entity of BlockManagerId to UcxEntity workerAddress. 
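With the jucx dependency bumped to 1.14.1 above, the smallest useful mental model for the UcxNode class that follows is a UCP context plus a worker. A minimal sketch using only the calls visible in this patch; the standalone main() and the progress loop are illustrative assumptions about usage, not code from this PR:

```
import org.openucx.jucx.ucp.UcpContext;
import org.openucx.jucx.ucp.UcpParams;
import org.openucx.jucx.ucp.UcpWorker;
import org.openucx.jucx.ucp.UcpWorkerParams;

public class JucxBootstrapSketch {
    public static void main(String[] args) throws Exception {
        // The same feature set UcxNode requests: tag matching, RMA and wakeup,
        // with workers that may be shared between threads.
        UcpParams params = new UcpParams().requestTagFeature()
            .requestRmaFeature().requestWakeupFeature()
            .setMtWorkersShared(true);
        UcpContext context = new UcpContext(params);
        UcpWorker worker = context.newWorker(new UcpWorkerParams());
        try {
            // Drive outstanding communication; UcxNode does this from a
            // dedicated listener-progress thread (UcxListenerThread).
            while (worker.progress() != 0) {
                // keep progressing while events are being processed
            }
        } finally {
            // Close in reverse order of creation, as UcxNode.close() does.
            worker.close();
            context.close();
        }
    }
}
```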
+ private static final ConcurrentHashMap workerAdresses = + new ConcurrentHashMap<>(); + private final Thread listenerProgressThread; + private boolean closed = false; + + // Driver + private UcpListener listener; + // Mapping from UcpEndpoint to ByteBuffer of RPC message, to introduce executor to cluster + private static final ConcurrentHashMap rpcConnections = + new ConcurrentHashMap<>(); + private List backwardEndpoints = new ArrayList<>(); + + // Executor + private UcpEndpoint globalDriverEndpoint; + // Keep track of allocated workers to correctly close them. + private static final Set allocatedWorkers = ConcurrentHashMap.newKeySet(); + private final ThreadLocal threadLocalWorker; + + public UcxNode(UcxShuffleConf conf, boolean isDriver) { + this.conf = conf; + this.isDriver = isDriver; + UcpParams params = new UcpParams().requestTagFeature() + .requestRmaFeature().requestWakeupFeature() + .setMtWorkersShared(true); + context = new UcpContext(params); + memoryPool = new MemoryPool(context, conf); + globalWorker = context.newWorker(workerParams); + InetSocketAddress driverAddress = new InetSocketAddress(conf.driverHost(), conf.driverPort()); + + if (isDriver) { + startDriver(driverAddress); + } else { + startExecutor(driverAddress); + } + + // Global listener thread, that keeps lazy progress for connection establishment + listenerProgressThread = new UcxListenerThread(this, isDriver); + listenerProgressThread.start(); + + if (!isDriver) { + memoryPool.preAlocate(); + } + + threadLocalWorker = ThreadLocal.withInitial(() -> { + UcpWorker localWorker = context.newWorker(workerParams); + UcxWorkerWrapper result = new UcxWorkerWrapper(localWorker, + conf, allocatedWorkers.size()); + if (result.id() > conf.coresPerProcess()) { + logger.warn("Thread: {} - creates new worker {} > numCores", + Thread.currentThread().getId(), result.id()); + } + allocatedWorkers.add(result); + return result; + }); + } + + private void startDriver(InetSocketAddress driverAddress) { + // 1. Start listener on a driver and accept RPC messages from executors with their + // worker addresses + UcpListenerParams listenerParams = new UcpListenerParams().setSockAddr(driverAddress) + .setConnectionHandler(ucpConnectionRequest -> + backwardEndpoints.add(globalWorker.newEndpoint(new UcpEndpointParams() + .setConnectionRequest(ucpConnectionRequest)))); + listener = globalWorker.newListener(listenerParams); + logger.info("Started UcxNode on {}", driverAddress); } - // Global listener thread, that keeps lazy progress for connection establishment - listenerProgressThread = new UcxListenerThread(this, isDriver); - listenerProgressThread.start(); + /** + * Allocates ByteBuffer from memoryPool and serializes there workerAddress, + * followed by BlockManagerID + * + * @return RegisteredMemory that holds metadata buffer. 
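The buildMetadataBuffer() method below assembles a simple length-prefixed RPC message. A hedged sketch of the same layout in plain NIO; packMetadata and the byte[] payload are illustrative stand-ins, not part of the patch:

```
import java.nio.ByteBuffer;

final class MetadataPackingSketch {
    // Layout: | addressLength:int | workerAddress bytes | serialized BlockManagerId |
    static ByteBuffer packMetadata(ByteBuffer workerAddress, byte[] blockManagerId, int capacity) {
        ByteBuffer buf = ByteBuffer.allocateDirect(capacity);
        buf.putInt(workerAddress.capacity()); // length prefix for the address
        buf.put(workerAddress);               // raw UCP worker address bytes
        buf.put(blockManagerId);              // serialized BlockManagerId payload
        buf.clear();                          // rewind so the whole message can be sent
        return buf;
    }
}
```

In the PR itself the buffer comes from the registered MemoryPool rather than allocateDirect, so the same bytes can be sent without an extra copy.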
+ */ + private RegisteredMemory buildMetadataBuffer() { + BlockManagerId blockManagerId = SparkEnv.get().blockManager().blockManagerId(); + ByteBuffer workerAddresses = globalWorker.getAddress(); + + RegisteredMemory metadataMemory = memoryPool.get(conf.metadataRPCBufferSize()); + ByteBuffer metadataBuffer = metadataMemory.getBuffer(); + metadataBuffer.putInt(workerAddresses.capacity()); + metadataBuffer.put(workerAddresses); + try { + SerializableBlockManagerID.serializeBlockManagerID(blockManagerId, metadataBuffer); + } catch (IOException e) { + String errorMsg = String.format("Failed to serialize %s: %s", blockManagerId, + e.getMessage()); + throw new UcxException(errorMsg); + } + metadataBuffer.clear(); + return metadataMemory; + } - if (!isDriver) { - memoryPool.preAlocate(); + private void startExecutor(InetSocketAddress driverAddress) { + // 1. Executor: connect to driver using sockaddr + // and send it's worker address followed by BlockManagerID. + globalDriverEndpoint = globalWorker.newEndpoint( + new UcpEndpointParams().setSocketAddress(driverAddress).setPeerErrorHandlingMode() + ); + + RegisteredMemory metadataMemory = buildMetadataBuffer(); + // TODO: send using stream API when it would be available in jucx. + globalDriverEndpoint.sendTaggedNonBlocking(metadataMemory.getBuffer(), new UcxCallback() { + @Override + public void onSuccess(UcpRequest request) { + memoryPool.put(metadataMemory); + } + }); } - threadLocalWorker = ThreadLocal.withInitial(() -> { - UcpWorker localWorker = context.newWorker(workerParams); - UcxWorkerWrapper result = new UcxWorkerWrapper(localWorker, - conf, allocatedWorkers.size()); - if (result.id() > conf.coresPerProcess()) { - logger.warn("Thread: {} - creates new worker {} > numCores", - Thread.currentThread().getId(), result.id()); - } - allocatedWorkers.add(result); - return result; - }); - } - - private void startDriver(InetSocketAddress driverAddress) { - // 1. Start listener on a driver and accept RPC messages from executors with their - // worker addresses - UcpListenerParams listenerParams = new UcpListenerParams().setSockAddr(driverAddress) - .setConnectionHandler(ucpConnectionRequest -> - backwardEndpoints.add(globalWorker.newEndpoint(new UcpEndpointParams() - .setConnectionRequest(ucpConnectionRequest)))); - listener = globalWorker.newListener(listenerParams); - logger.info("Started UcxNode on {}", driverAddress); - } - - /** - * Allocates ByteBuffer from memoryPool and serializes there workerAddress, - * followed by BlockManagerID - * @return RegisteredMemory that holds metadata buffer. - */ - private RegisteredMemory buildMetadataBuffer() { - BlockManagerId blockManagerId = SparkEnv.get().blockManager().blockManagerId(); - ByteBuffer workerAddresses = globalWorker.getAddress(); - - RegisteredMemory metadataMemory = memoryPool.get(conf.metadataRPCBufferSize()); - ByteBuffer metadataBuffer = metadataMemory.getBuffer(); - metadataBuffer.putInt(workerAddresses.capacity()); - metadataBuffer.put(workerAddresses); - try { - SerializableBlockManagerID.serializeBlockManagerID(blockManagerId, metadataBuffer); - } catch (IOException e) { - String errorMsg = String.format("Failed to serialize %s: %s", blockManagerId, - e.getMessage()); - throw new UcxException(errorMsg); + public UcxShuffleConf getConf() { + return conf; } - metadataBuffer.clear(); - return metadataMemory; - } - - private void startExecutor(InetSocketAddress driverAddress) { - // 1. 
Executor: connect to driver using sockaddr - // and send it's worker address followed by BlockManagerID. - globalDriverEndpoint = globalWorker.newEndpoint( - new UcpEndpointParams().setSocketAddress(driverAddress).setPeerErrorHandlingMode() - ); - - RegisteredMemory metadataMemory = buildMetadataBuffer(); - // TODO: send using stream API when it would be available in jucx. - globalDriverEndpoint.sendTaggedNonBlocking(metadataMemory.getBuffer(), new UcxCallback() { - @Override - public void onSuccess(UcpRequest request) { - memoryPool.put(metadataMemory); - } - }); - } - - public UcxShuffleConf getConf() { - return conf; - } - - public UcpWorker getGlobalWorker() { - return globalWorker; - } - - public MemoryPool getMemoryPool() { - return memoryPool; - } - - public UcpContext getContext() { - return context; - } - - /** - * Get or initialize worker for current thread - */ - public UcxWorkerWrapper getThreadLocalWorker() { - return threadLocalWorker.get(); - } - - public static ConcurrentMap getWorkerAddresses() { - return workerAdresses; - } - - public static ConcurrentMap getRpcConnections() { - return rpcConnections; - } - - private void stopDriver() { - if (listener != null) { - listener.close(); - listener = null; + + public UcpWorker getGlobalWorker() { + return globalWorker; } - rpcConnections.keySet().forEach(UcpEndpoint::close); - rpcConnections.clear(); - backwardEndpoints.forEach(UcpEndpoint::close); - backwardEndpoints.clear(); - } - - private void stopExecutor() { - if (globalDriverEndpoint != null) { - globalDriverEndpoint.close(); - globalDriverEndpoint = null; + + public MemoryPool getMemoryPool() { + return memoryPool; } - allocatedWorkers.forEach(UcxWorkerWrapper::close); - allocatedWorkers.clear(); - } - - @Override - public void close() { - threadLocalWorker.remove(); - synchronized (this) { - if (!closed) { - logger.info("Stopping UcxNode"); - listenerProgressThread.interrupt(); - globalWorker.signal(); - try { - listenerProgressThread.join(); - if (isDriver) { - stopDriver(); - } else { - stopExecutor(); - } - memoryPool.close(); - globalWorker.close(); - context.close(); - closed = true; - } catch (InterruptedException e) { - logger.error(e.getMessage()); - Thread.currentThread().interrupt(); - } catch (Exception ex) { - logger.warn(ex.getLocalizedMessage()); + + public UcpContext getContext() { + return context; + } + + /** + * Get or initialize worker for current thread + */ + public UcxWorkerWrapper getThreadLocalWorker() { + return threadLocalWorker.get(); + } + + public static ConcurrentMap getWorkerAddresses() { + return workerAdresses; + } + + public static ConcurrentMap getRpcConnections() { + return rpcConnections; + } + + private void stopDriver() { + if (listener != null) { + listener.close(); + listener = null; + } + rpcConnections.keySet().forEach(UcpEndpoint::close); + rpcConnections.clear(); + backwardEndpoints.forEach(UcpEndpoint::close); + backwardEndpoints.clear(); + } + + private void stopExecutor() { + if (globalDriverEndpoint != null) { + globalDriverEndpoint.close(); + globalDriverEndpoint = null; + } + allocatedWorkers.forEach(UcxWorkerWrapper::close); + allocatedWorkers.clear(); + } + + @Override + public void close() { + threadLocalWorker.remove(); + synchronized (this) { + if (!closed) { + logger.info("Stopping UcxNode"); + listenerProgressThread.interrupt(); + globalWorker.signal(); + try { + listenerProgressThread.join(); + if (isDriver) { + stopDriver(); + } else { + stopExecutor(); + } + memoryPool.close(); + globalWorker.close(); + 
context.close(); + closed = true; + } catch (InterruptedException e) { + logger.error(e.getMessage()); + Thread.currentThread().interrupt(); + } catch (Exception ex) { + logger.warn(ex.getLocalizedMessage()); + } + } } - } } - } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/UnsafeUtils.java b/src/main/java/org/apache/spark/shuffle/ucx/UnsafeUtils.java index 9e9e0843..adbfc8c8 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/UnsafeUtils.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/UnsafeUtils.java @@ -20,58 +20,59 @@ * Java's native mmap functionality, that allows to mmap files > 2GB. */ public class UnsafeUtils { - private static final Method mmap; - private static final Method unmmap; - private static final Logger logger = LoggerFactory.getLogger(UnsafeUtils.class); + private static final Method mmap; + private static final Method unmmap; + private static final Logger logger = LoggerFactory.getLogger(UnsafeUtils.class); - private static final Constructor directBufferConstructor; + private static final Constructor directBufferConstructor; - public static final int LONG_SIZE = 8; - public static final int INT_SIZE = 4; + public static final int LONG_SIZE = 8; + public static final int INT_SIZE = 4; - static { - try { - mmap = FileChannelImpl.class.getDeclaredMethod("map0", int.class, long.class, long.class); - mmap.setAccessible(true); - unmmap = FileChannelImpl.class.getDeclaredMethod("unmap0", long.class, long.class); - unmmap.setAccessible(true); - Class classDirectByteBuffer = Class.forName("java.nio.DirectByteBuffer"); - directBufferConstructor = classDirectByteBuffer.getDeclaredConstructor(long.class, int.class); - directBufferConstructor.setAccessible(true); - } catch (Exception e) { - throw new RuntimeException(e); + static { + try { + mmap = FileChannelImpl.class.getDeclaredMethod("map0", int.class, long.class, long.class); + mmap.setAccessible(true); + unmmap = FileChannelImpl.class.getDeclaredMethod("unmap0", long.class, long.class); + unmmap.setAccessible(true); + Class classDirectByteBuffer = Class.forName("java.nio.DirectByteBuffer"); + directBufferConstructor = classDirectByteBuffer.getDeclaredConstructor(long.class, int.class); + directBufferConstructor.setAccessible(true); + } catch (Exception e) { + throw new RuntimeException(e); + } } - } - private UnsafeUtils() {} + private UnsafeUtils() { + } - public static long mmap(FileChannel fileChannel, long offset, long length) { - long result; - try { - result = (long)mmap.invoke(fileChannel, 1, offset, length); - } catch (Exception e) { - logger.error("MMap({}, {}) failed: {}", offset, length, e.getMessage()); - throw new UcxException(e.getMessage()); + public static long mmap(FileChannel fileChannel, long offset, long length) { + long result; + try { + result = (long) mmap.invoke(fileChannel, 1, offset, length); + } catch (Exception e) { + logger.error("MMap({}, {}) failed: {}", offset, length, e.getMessage()); + throw new UcxException(e.getMessage()); + } + return result; } - return result; - } - public static void munmap(long address, long length) { - try { - unmmap.invoke(null, address, length); - } catch (IllegalAccessException | InvocationTargetException e) { - logger.error(e.getMessage()); + public static void munmap(long address, long length) { + try { + unmmap.invoke(null, address, length); + } catch (IllegalAccessException | InvocationTargetException e) { + logger.error(e.getMessage()); + } } - } - public static ByteBuffer getByteBuffer(long address, int length) throws IOException { - try { 
- return (ByteBuffer)directBufferConstructor.newInstance(address, length); - } catch (InvocationTargetException ex) { - throw new IOException("java.nio.DirectByteBuffer: " + - "InvocationTargetException: " + ex.getTargetException()); - } catch (Exception e) { - throw new IOException("java.nio.DirectByteBuffer exception: " + e.getMessage()); + public static ByteBuffer getByteBuffer(long address, int length) throws IOException { + try { + return (ByteBuffer) directBufferConstructor.newInstance(address, length); + } catch (InvocationTargetException ex) { + throw new IOException("java.nio.DirectByteBuffer: " + + "InvocationTargetException: " + ex.getTargetException()); + } catch (Exception e) { + throw new IOException("java.nio.DirectByteBuffer exception: " + e.getMessage()); + } } - } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/memory/MemoryPool.java b/src/main/java/org/apache/spark/shuffle/ucx/memory/MemoryPool.java index 200e9ee2..2d5979ec 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/memory/MemoryPool.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/memory/MemoryPool.java @@ -25,155 +25,155 @@ * and registration during shuffle phase. */ public class MemoryPool implements Closeable { - private static final Logger logger = LoggerFactory.getLogger(MemoryPool.class); - - @Override - public void close() { - for (AllocatorStack stack: allocStackMap.values()) { - stack.close(); - logger.info("Stack of size {}. " + - "Total requests: {}, total allocations: {}, preAllocations: {}", - stack.length, stack.totalRequests.get(), stack.totalAlloc.get(), stack.preAllocs.get()); - } - allocStackMap.clear(); - } - - private class AllocatorStack implements Closeable { - private final AtomicInteger totalRequests = new AtomicInteger(0); - private final AtomicInteger totalAlloc = new AtomicInteger(0); - private final AtomicInteger preAllocs = new AtomicInteger(0); - private final ConcurrentLinkedDeque stack = new ConcurrentLinkedDeque<>(); - private final int length; - - private AllocatorStack(int length) { - this.length = length; - } + private static final Logger logger = LoggerFactory.getLogger(MemoryPool.class); - private RegisteredMemory get() { - RegisteredMemory result = stack.pollFirst(); - if (result == null) { - if (length < conf.minRegistrationSize()) { - int numBuffers = conf.minRegistrationSize() / length; - logger.debug("Allocating {} buffers of size {}", numBuffers, length); - preallocate(numBuffers); - result = stack.pollFirst(); - if (result == null) { - return get(); - } else { - result.getRefCount().incrementAndGet(); - } - } else { - UcpMemMapParams memMapParams = new UcpMemMapParams().setLength(length).allocate(); - UcpMemory memory = context.memoryMap(memMapParams); - ByteBuffer buffer; - try { - buffer = UcxUtils.getByteBufferView(memory.getAddress(), (int)memory.getLength()); - } catch (Exception e) { - throw new UcxException(e.getMessage()); - } - result = new RegisteredMemory(new AtomicInteger(1), memory, buffer); - totalAlloc.incrementAndGet(); + @Override + public void close() { + for (AllocatorStack stack : allocStackMap.values()) { + stack.close(); + logger.info("Stack of size {}. 
" + + "Total requests: {}, total allocations: {}, preAllocations: {}", + stack.length, stack.totalRequests.get(), stack.totalAlloc.get(), stack.preAllocs.get()); } - } else { - result.getRefCount().incrementAndGet(); - } - totalRequests.incrementAndGet(); - return result; + allocStackMap.clear(); } - private void put(RegisteredMemory registeredMemory) { - registeredMemory.getRefCount().decrementAndGet(); - stack.addLast(registeredMemory); + private class AllocatorStack implements Closeable { + private final AtomicInteger totalRequests = new AtomicInteger(0); + private final AtomicInteger totalAlloc = new AtomicInteger(0); + private final AtomicInteger preAllocs = new AtomicInteger(0); + private final ConcurrentLinkedDeque stack = new ConcurrentLinkedDeque<>(); + private final int length; + + private AllocatorStack(int length) { + this.length = length; + } + + private RegisteredMemory get() { + RegisteredMemory result = stack.pollFirst(); + if (result == null) { + if (length < conf.minRegistrationSize()) { + int numBuffers = conf.minRegistrationSize() / length; + logger.debug("Allocating {} buffers of size {}", numBuffers, length); + preallocate(numBuffers); + result = stack.pollFirst(); + if (result == null) { + return get(); + } else { + result.getRefCount().incrementAndGet(); + } + } else { + UcpMemMapParams memMapParams = new UcpMemMapParams().setLength(length).allocate(); + UcpMemory memory = context.memoryMap(memMapParams); + ByteBuffer buffer; + try { + buffer = UcxUtils.getByteBufferView(memory.getAddress(), (int) memory.getLength()); + } catch (Exception e) { + throw new UcxException(e.getMessage()); + } + result = new RegisteredMemory(new AtomicInteger(1), memory, buffer); + totalAlloc.incrementAndGet(); + } + } else { + result.getRefCount().incrementAndGet(); + } + totalRequests.incrementAndGet(); + return result; + } + + private void put(RegisteredMemory registeredMemory) { + registeredMemory.getRefCount().decrementAndGet(); + stack.addLast(registeredMemory); + } + + private void preallocate(int numBuffers) { + // Platform.allocateDirectBuffer supports only 2GB of buffer. + // Decrease number of buffers if total size of preAllocation > 2GB. + if ((long) length * (long) numBuffers > Integer.MAX_VALUE) { + numBuffers = Integer.MAX_VALUE / length; + } + + UcpMemMapParams memMapParams = new UcpMemMapParams().allocate().setLength(numBuffers * (long) length); + UcpMemory memory = context.memoryMap(memMapParams); + ByteBuffer buffer; + try { + buffer = UnsafeUtils.getByteBuffer(memory.getAddress(), numBuffers * length); + } catch (Exception ex) { + throw new UcxException(ex.getMessage()); + } + + AtomicInteger refCount = new AtomicInteger(numBuffers); + for (int i = 0; i < numBuffers; i++) { + buffer.position(i * length).limit(i * length + length); + final ByteBuffer slice = buffer.slice(); + RegisteredMemory registeredMemory = new RegisteredMemory(refCount, memory, slice); + put(registeredMemory); + } + preAllocs.incrementAndGet(); + totalAlloc.incrementAndGet(); + } + + @Override + public void close() { + while (!stack.isEmpty()) { + RegisteredMemory memory = stack.pollFirst(); + if (memory != null) { + memory.deregisterNativeMemory(); + } + } + } } - private void preallocate(int numBuffers) { - // Platform.allocateDirectBuffer supports only 2GB of buffer. - // Decrease number of buffers if total size of preAllocation > 2GB. 
- if ((long)length * (long)numBuffers > Integer.MAX_VALUE) { - numBuffers = Integer.MAX_VALUE / length; - } - - UcpMemMapParams memMapParams = new UcpMemMapParams().allocate().setLength(numBuffers * (long)length); - UcpMemory memory = context.memoryMap(memMapParams); - ByteBuffer buffer; - try { - buffer = UnsafeUtils.getByteBuffer(memory.getAddress(), numBuffers * length); - } catch (Exception ex) { - throw new UcxException(ex.getMessage()); - } - - AtomicInteger refCount = new AtomicInteger(numBuffers); - for (int i = 0; i < numBuffers; i++) { - buffer.position(i * length).limit(i * length + length); - final ByteBuffer slice = buffer.slice(); - RegisteredMemory registeredMemory = new RegisteredMemory(refCount, memory, slice); - put(registeredMemory); - } - preAllocs.incrementAndGet(); - totalAlloc.incrementAndGet(); + private final ConcurrentHashMap allocStackMap = + new ConcurrentHashMap<>(); + private final UcpContext context; + private final UcxShuffleConf conf; + + public MemoryPool(UcpContext context, UcxShuffleConf conf) { + this.context = context; + this.conf = conf; } - @Override - public void close() { - while (!stack.isEmpty()) { - RegisteredMemory memory = stack.pollFirst(); - if (memory != null) { - memory.deregisterNativeMemory(); + private long roundUpToTheNextPowerOf2(long length) { + // Round up length to the nearest power of two, or the minimum block size + if (length < conf.minBufferSize()) { + length = conf.minBufferSize(); + } else { + length--; + length |= length >> 1; + length |= length >> 2; + length |= length >> 4; + length |= length >> 8; + length |= length >> 16; + length++; } - } + return length; + } + + public RegisteredMemory get(int size) { + long roundedSize = roundUpToTheNextPowerOf2(size); + assert roundedSize < Integer.MAX_VALUE && roundedSize > 0; + AllocatorStack stack = + allocStackMap.computeIfAbsent((int) roundedSize, AllocatorStack::new); + RegisteredMemory result = stack.get(); + result.getBuffer().position(0).limit(size); + return result; } - } - - private final ConcurrentHashMap allocStackMap = - new ConcurrentHashMap<>(); - private final UcpContext context; - private final UcxShuffleConf conf; - - public MemoryPool(UcpContext context, UcxShuffleConf conf) { - this.context = context; - this.conf = conf; - } - - private long roundUpToTheNextPowerOf2(long length) { - // Round up length to the nearest power of two, or the minimum block size - if (length < conf.minBufferSize()) { - length = conf.minBufferSize(); - } else { - length--; - length |= length >> 1; - length |= length >> 2; - length |= length >> 4; - length |= length >> 8; - length |= length >> 16; - length++; + + public void put(RegisteredMemory memory) { + AllocatorStack allocatorStack = allocStackMap.get(memory.getBuffer().capacity()); + if (allocatorStack != null) { + allocatorStack.put(memory); + } } - return length; - } - - public RegisteredMemory get(int size) { - long roundedSize = roundUpToTheNextPowerOf2(size); - assert roundedSize < Integer.MAX_VALUE && roundedSize > 0; - AllocatorStack stack = - allocStackMap.computeIfAbsent((int)roundedSize, AllocatorStack::new); - RegisteredMemory result = stack.get(); - result.getBuffer().position(0).limit(size); - return result; - } - - public void put(RegisteredMemory memory) { - AllocatorStack allocatorStack = allocStackMap.get(memory.getBuffer().capacity()); - if (allocatorStack != null) { - allocatorStack.put(memory); + + public void preAlocate() { + conf.preallocateBuffersMap().forEach((size, numBuffers) -> { + logger.debug("Pre 
allocating {} buffers of size {}", numBuffers, size); + AllocatorStack stack = new AllocatorStack(size); + allocStackMap.put(size, stack); + stack.preallocate(numBuffers); + }); } - } - - public void preAlocate() { - conf.preallocateBuffersMap().forEach((size, numBuffers) -> { - logger.debug("Pre allocating {} buffers of size {}", numBuffers, size); - AllocatorStack stack = new AllocatorStack(size); - allocStackMap.put(size, stack); - stack.preallocate(numBuffers); - }); - } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/memory/RegisteredMemory.java b/src/main/java/org/apache/spark/shuffle/ucx/memory/RegisteredMemory.java index 47be67ef..9f8fbd5c 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/memory/RegisteredMemory.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/memory/RegisteredMemory.java @@ -12,32 +12,32 @@ * Keeps track on reference count to memory region. */ public class RegisteredMemory { - private static final Logger logger = LoggerFactory.getLogger(RegisteredMemory.class); + private static final Logger logger = LoggerFactory.getLogger(RegisteredMemory.class); - private final AtomicInteger refcount; - private final UcpMemory memory; - private final ByteBuffer buffer; + private final AtomicInteger refcount; + private final UcpMemory memory; + private final ByteBuffer buffer; - RegisteredMemory(AtomicInteger refcount, UcpMemory memory, ByteBuffer buffer) { - this.refcount = refcount; - this.memory = memory; - this.buffer = buffer; - } - - public ByteBuffer getBuffer() { - return buffer; - } + RegisteredMemory(AtomicInteger refcount, UcpMemory memory, ByteBuffer buffer) { + this.refcount = refcount; + this.memory = memory; + this.buffer = buffer; + } - AtomicInteger getRefCount() { - return refcount; - } + public ByteBuffer getBuffer() { + return buffer; + } - void deregisterNativeMemory() { - if (refcount.get() != 0) { - logger.warn("De-registering memory of size {} that has active references.", buffer.capacity()); + AtomicInteger getRefCount() { + return refcount; } - if (memory != null && memory.getNativeId() != null) { - memory.deregister(); + + void deregisterNativeMemory() { + if (refcount.get() != 0) { + logger.warn("De-registering memory of size {} that has active references.", buffer.capacity()); + } + if (memory != null && memory.getNativeId() != null) { + memory.deregister(); + } } - } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/OnBlocksFetchCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/OnBlocksFetchCallback.java index 57f2fe83..5938e81e 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/OnBlocksFetchCallback.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/OnBlocksFetchCallback.java @@ -20,39 +20,39 @@ * Notifies Spark's shuffleFetchIterator on block fetch completion. 
*/ public class OnBlocksFetchCallback extends ReducerCallback { - protected RegisteredMemory blocksMemory; - protected int[] sizes; + protected RegisteredMemory blocksMemory; + protected int[] sizes; - public OnBlocksFetchCallback(ReducerCallback callback, RegisteredMemory blocksMemory, int[] sizes) { - super(callback); - this.blocksMemory = blocksMemory; - this.sizes = sizes; - } + public OnBlocksFetchCallback(ReducerCallback callback, RegisteredMemory blocksMemory, int[] sizes) { + super(callback); + this.blocksMemory = blocksMemory; + this.sizes = sizes; + } - @Override - public void onSuccess(UcpRequest request) { - int position = 0; - AtomicInteger refCount = new AtomicInteger(blockIds.length); - for (int i = 0; i < blockIds.length; i++) { - BlockId block = blockIds[i]; - // Blocks are fetched to contiguous buffer. - // |----block1---||---block2---||---block3---| - // Slice each block to avoid buffer copy. - blocksMemory.getBuffer().position(position).limit(position + sizes[i]); - ByteBuffer blockBuffer = blocksMemory.getBuffer().slice(); - position += sizes[i]; - // Pass block to Spark's ShuffleFetchIterator. - listener.onBlockFetchSuccess(block.name(), new NioManagedBuffer(blockBuffer) { - @Override - public ManagedBuffer release() { - if (refCount.decrementAndGet() == 0) { - mempool.put(blocksMemory); - } - return this; + @Override + public void onSuccess(UcpRequest request) { + int position = 0; + AtomicInteger refCount = new AtomicInteger(blockIds.length); + for (int i = 0; i < blockIds.length; i++) { + BlockId block = blockIds[i]; + // Blocks are fetched to contiguous buffer. + // |----block1---||---block2---||---block3---| + // Slice each block to avoid buffer copy. + blocksMemory.getBuffer().position(position).limit(position + sizes[i]); + ByteBuffer blockBuffer = blocksMemory.getBuffer().slice(); + position += sizes[i]; + // Pass block to Spark's ShuffleFetchIterator. + listener.onBlockFetchSuccess(block.name(), new NioManagedBuffer(blockBuffer) { + @Override + public ManagedBuffer release() { + if (refCount.decrementAndGet() == 0) { + mempool.put(blocksMemory); + } + return this; + } + }); } - }); + logger.info("Endpoint {} fetched {} blocks of total size {} in {}ms", endpoint.getNativeId(), blockIds.length, + Utils.bytesToString(position), System.currentTimeMillis() - startTime); } - logger.info("Endpoint {} fetched {} blocks of total size {} in {}ms", endpoint.getNativeId(), blockIds.length, - Utils.bytesToString(position), System.currentTimeMillis() - startTime); - } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/ReducerCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/ReducerCallback.java index 612ce9c3..1d722539 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/ReducerCallback.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/ReducerCallback.java @@ -18,25 +18,25 @@ * Common data needed for offset fetch callback and subsequent block fetch callback. 
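OnBlocksFetchCallback.onSuccess() above receives all blocks in one contiguous registered buffer and hands each of them to Spark as a zero-copy slice. The position/limit/slice walk in isolation:

```
import java.nio.ByteBuffer;

final class BlockSlicingSketch {
    // |----block1---||---block2---||---block3---| -> one zero-copy view per block
    static ByteBuffer[] splitBlocks(ByteBuffer fetched, int[] sizes) {
        ByteBuffer[] blocks = new ByteBuffer[sizes.length];
        int position = 0;
        for (int i = 0; i < sizes.length; i++) {
            fetched.position(position).limit(position + sizes[i]);
            blocks[i] = fetched.slice(); // shares the backing registered memory
            position += sizes[i];
        }
        return blocks;
    }
}
```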
*/ public abstract class ReducerCallback extends UcxCallback { - protected MemoryPool mempool; - protected BlockId[] blockIds; - protected UcpEndpoint endpoint; - protected BlockFetchingListener listener; - protected static final Logger logger = LoggerFactory.getLogger(ReducerCallback.class); - protected long startTime = System.currentTimeMillis(); + protected MemoryPool mempool; + protected BlockId[] blockIds; + protected UcpEndpoint endpoint; + protected BlockFetchingListener listener; + protected static final Logger logger = LoggerFactory.getLogger(ReducerCallback.class); + protected long startTime = System.currentTimeMillis(); - public ReducerCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener) { - this.mempool = ((CommonUcxShuffleManager)SparkEnv.get().shuffleManager()).ucxNode().getMemoryPool(); - this.blockIds = blockIds; - this.endpoint = endpoint; - this.listener = listener; - } + public ReducerCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener) { + this.mempool = ((CommonUcxShuffleManager) SparkEnv.get().shuffleManager()).ucxNode().getMemoryPool(); + this.blockIds = blockIds; + this.endpoint = endpoint; + this.listener = listener; + } - public ReducerCallback(ReducerCallback callback) { - this.blockIds = callback.blockIds; - this.endpoint = callback.endpoint; - this.listener = callback.listener; - this.mempool = callback.mempool; - this.startTime = callback.startTime; - } + public ReducerCallback(ReducerCallback callback) { + this.blockIds = callback.blockIds; + this.endpoint = callback.endpoint; + this.listener = callback.listener; + this.mempool = callback.mempool; + this.startTime = callback.startTime; + } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/OnOffsetsFetchCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/OnOffsetsFetchCallback.java index 5f95caf6..93381610 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/OnOffsetsFetchCallback.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/OnOffsetsFetchCallback.java @@ -22,49 +22,49 @@ * Callback, called when got all offsets for blocks */ public class OnOffsetsFetchCallback extends ReducerCallback { - private final RegisteredMemory offsetMemory; - private final long[] dataAddresses; - private Map dataRkeysCache; + private final RegisteredMemory offsetMemory; + private final long[] dataAddresses; + private Map dataRkeysCache; - public OnOffsetsFetchCallback(ShuffleBlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener, - RegisteredMemory offsetMemory, long[] dataAddresses, - Map dataRkeysCache) { - super(blockIds, endpoint, listener); - this.offsetMemory = offsetMemory; - this.dataAddresses = dataAddresses; - this.dataRkeysCache = dataRkeysCache; - } - - @Override - public void onSuccess(UcpRequest request) { - ByteBuffer resultOffset = offsetMemory.getBuffer(); - long totalSize = 0; - int[] sizes = new int[blockIds.length]; - int offsetSize = UnsafeUtils.LONG_SIZE; - for (int i = 0; i < blockIds.length; i++) { - // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd | - long blockOffset = resultOffset.getLong(i * 2 * offsetSize); - long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset; - assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE); - sizes[i] = (int) blockLength; - totalSize += blockLength; - dataAddresses[i] += blockOffset; + public 
OnOffsetsFetchCallback(ShuffleBlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener, + RegisteredMemory offsetMemory, long[] dataAddresses, + Map dataRkeysCache) { + super(blockIds, endpoint, listener); + this.offsetMemory = offsetMemory; + this.dataAddresses = dataAddresses; + this.dataRkeysCache = dataRkeysCache; } - assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE); - mempool.put(offsetMemory); - RegisteredMemory blocksMemory = mempool.get((int) totalSize); + @Override + public void onSuccess(UcpRequest request) { + ByteBuffer resultOffset = offsetMemory.getBuffer(); + long totalSize = 0; + int[] sizes = new int[blockIds.length]; + int offsetSize = UnsafeUtils.LONG_SIZE; + for (int i = 0; i < blockIds.length; i++) { + // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd | + long blockOffset = resultOffset.getLong(i * 2 * offsetSize); + long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset; + assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE); + sizes[i] = (int) blockLength; + totalSize += blockLength; + dataAddresses[i] += blockOffset; + } - long offset = 0; - // Submits N fetch blocks requests - for (int i = 0; i < blockIds.length; i++) { - endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(((ShuffleBlockId)blockIds[i]).mapId()), - UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]); - offset += sizes[i]; - } + assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE); + mempool.put(offsetMemory); + RegisteredMemory blocksMemory = mempool.get((int) totalSize); - // Process blocks when all fetched. - // Flush guarantees that callback would invoke when all fetch requests will completed. - endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes)); - } + long offset = 0; + // Submits N fetch blocks requests + for (int i = 0; i < blockIds.length; i++) { + endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(((ShuffleBlockId) blockIds[i]).mapId()), + UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]); + offset += sizes[i]; + } + + // Process blocks when all fetched. + // Flush guarantees that callback would invoke when all fetch requests will completed. 
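The offset decoding in onSuccess() above, shown in isolation: each block contributes one |blockOffsetStart|blockOffsetEnd| pair of longs, and the block size is simply their difference (offsetSize is UnsafeUtils.LONG_SIZE, i.e. 8):

```
import java.nio.ByteBuffer;

final class OffsetDecodingSketch {
    static int[] decodeSizes(ByteBuffer resultOffset, int numBlocks) {
        final int offsetSize = 8; // UnsafeUtils.LONG_SIZE
        int[] sizes = new int[numBlocks];
        for (int i = 0; i < numBlocks; i++) {
            long start = resultOffset.getLong(i * 2 * offsetSize);
            long end = resultOffset.getLong(i * 2 * offsetSize + offsetSize);
            sizes[i] = (int) (end - start); // each block must fit in an int
        }
        return sizes;
    }
}
```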
+ endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes)); + } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/UcxShuffleClient.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/UcxShuffleClient.java index 0595bfc8..4d577b75 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/UcxShuffleClient.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/UcxShuffleClient.java @@ -28,84 +28,83 @@ import java.util.HashMap; public class UcxShuffleClient extends ShuffleClient { - private final MemoryPool mempool; - private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class); - private final UcxShuffleManager ucxShuffleManager; - private final TempShuffleReadMetrics shuffleReadMetrics; - private final UcxWorkerWrapper workerWrapper; - final HashMap offsetRkeysCache = new HashMap<>(); - final HashMap dataRkeysCache = new HashMap<>(); - - public UcxShuffleClient(TempShuffleReadMetrics shuffleReadMetrics, - UcxWorkerWrapper workerWrapper) { - this.ucxShuffleManager = (UcxShuffleManager) SparkEnv.get().shuffleManager(); - this.mempool = ucxShuffleManager.ucxNode().getMemoryPool(); - this.shuffleReadMetrics = shuffleReadMetrics; - this.workerWrapper = workerWrapper; - } - - /** - * Submits n non blocking fetch offsets to get needed offsets for n blocks. - */ - private void submitFetchOffsets(UcpEndpoint endpoint, ShuffleBlockId[] blockIds, - long[] dataAddresses, RegisteredMemory offsetMemory) { - DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(blockIds[0].shuffleId()); - for (int i = 0; i < blockIds.length; i++) { - ShuffleBlockId blockId = blockIds[i]; - - long offsetAddress = driverMetadata.offsetAddress(blockId.mapId()); - dataAddresses[i] = driverMetadata.dataAddress(blockId.mapId()); - - offsetRkeysCache.computeIfAbsent(blockId.mapId(), mapId -> - endpoint.unpackRemoteKey(driverMetadata.offsetRkey(blockId.mapId()))); - - dataRkeysCache.computeIfAbsent(blockId.mapId(), mapId -> - endpoint.unpackRemoteKey(driverMetadata.dataRkey(blockId.mapId()))); - - endpoint.getNonBlockingImplicit( - offsetAddress + blockId.reduceId() * UnsafeUtils.LONG_SIZE, - offsetRkeysCache.get(blockId.mapId()), - UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UnsafeUtils.LONG_SIZE), - 2L * UnsafeUtils.LONG_SIZE); + private final MemoryPool mempool; + private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class); + private final UcxShuffleManager ucxShuffleManager; + private final TempShuffleReadMetrics shuffleReadMetrics; + private final UcxWorkerWrapper workerWrapper; + final HashMap offsetRkeysCache = new HashMap<>(); + final HashMap dataRkeysCache = new HashMap<>(); + + public UcxShuffleClient(TempShuffleReadMetrics shuffleReadMetrics, + UcxWorkerWrapper workerWrapper) { + this.ucxShuffleManager = (UcxShuffleManager) SparkEnv.get().shuffleManager(); + this.mempool = ucxShuffleManager.ucxNode().getMemoryPool(); + this.shuffleReadMetrics = shuffleReadMetrics; + this.workerWrapper = workerWrapper; + } + + /** + * Submits n non blocking fetch offsets to get needed offsets for n blocks. 
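The submitFetchOffsets() below caches unpacked remote keys per map id: unpacking an rkey has a cost, every block from the same map output reuses the same key, and UcxShuffleClient.close() later releases them all. The pattern in isolation, assuming the packed rkey arrives as a ByteBuffer as it does in this diff:

```
import java.nio.ByteBuffer;
import java.util.HashMap;
import org.openucx.jucx.ucp.UcpEndpoint;
import org.openucx.jucx.ucp.UcpRemoteKey;

final class RkeyCacheSketch {
    private final HashMap<Integer, UcpRemoteKey> cache = new HashMap<>();

    // Unpack each remote key once per endpoint, then reuse it for every
    // block of the same map output.
    UcpRemoteKey rkeyFor(UcpEndpoint endpoint, int mapId, ByteBuffer packedRkey) {
        return cache.computeIfAbsent(mapId, id -> endpoint.unpackRemoteKey(packedRkey));
    }

    void close() {
        cache.values().forEach(UcpRemoteKey::close);
        cache.clear();
    }
}
```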
+ */ + private void submitFetchOffsets(UcpEndpoint endpoint, ShuffleBlockId[] blockIds, + long[] dataAddresses, RegisteredMemory offsetMemory) { + DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(blockIds[0].shuffleId()); + for (int i = 0; i < blockIds.length; i++) { + ShuffleBlockId blockId = blockIds[i]; + + long offsetAddress = driverMetadata.offsetAddress(blockId.mapId()); + dataAddresses[i] = driverMetadata.dataAddress(blockId.mapId()); + + offsetRkeysCache.computeIfAbsent(blockId.mapId(), mapId -> + endpoint.unpackRemoteKey(driverMetadata.offsetRkey(blockId.mapId()))); + + dataRkeysCache.computeIfAbsent(blockId.mapId(), mapId -> + endpoint.unpackRemoteKey(driverMetadata.dataRkey(blockId.mapId()))); + + endpoint.getNonBlockingImplicit( + offsetAddress + blockId.reduceId() * UnsafeUtils.LONG_SIZE, + offsetRkeysCache.get(blockId.mapId()), + UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UnsafeUtils.LONG_SIZE), + 2L * UnsafeUtils.LONG_SIZE); + } + } + + /** + * Reducer entry point. Fetches remote blocks, using 2 ucp_get calls. + * This method is inside ShuffleFetchIterator's for loop over hosts. + * First fetches block offset from index file, and then fetches block itself. + */ + @Override + public void fetchBlocks(String host, int port, String execId, + String[] blockIds, BlockFetchingListener listener) { + long startTime = System.currentTimeMillis(); + + BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty()); + UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId); + + long[] dataAddresses = new long[blockIds.length]; + + // Need to fetch 2 long offsets current block + next block to calculate exact block size. + RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length); + + ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds) + .map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new); + + // Submits N implicit get requests without callback + submitFetchOffsets(endpoint, shuffleBlockIds, dataAddresses, offsetMemory); + + // flush guarantees that all that requests completes when callback is called. + endpoint.flushNonBlocking( + new OnOffsetsFetchCallback(shuffleBlockIds, endpoint, listener, offsetMemory, + dataAddresses, dataRkeysCache)); + shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime); + } + + @Override + public void close() { + offsetRkeysCache.values().forEach(UcpRemoteKey::close); + dataRkeysCache.values().forEach(UcpRemoteKey::close); + logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime()); } - } - - /** - * Reducer entry point. Fetches remote blocks, using 2 ucp_get calls. - * This method is inside ShuffleFetchIterator's for loop over hosts. - * First fetches block offset from index file, and then fetches block itself. - */ - @Override - public void fetchBlocks(String host, int port, String execId, - String[] blockIds, BlockFetchingListener listener) { - long startTime = System.currentTimeMillis(); - - BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty()); - UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId); - - long[] dataAddresses = new long[blockIds.length]; - - // Need to fetch 2 long offsets current block + next block to calculate exact block size. 
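One sizing detail of fetchBlocks() worth spelling out: the scratch buffer requested from the pool holds an (offsetStart, offsetEnd) pair, i.e. two 8-byte longs, per requested block:

```
final class OffsetBufferSizingSketch {
    static int offsetBufferSize(int numBlocks) {
        final int longSize = 8;          // UnsafeUtils.LONG_SIZE
        return 2 * longSize * numBlocks; // e.g. 50 blocks -> 800 bytes
    }
}
```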
- RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length); - - ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds) - .map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new); - - // Submits N implicit get requests without callback - submitFetchOffsets(endpoint, shuffleBlockIds, dataAddresses, offsetMemory); - - // flush guarantees that all that requests completes when callback is called. - // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush. - workerWrapper.worker().flushNonBlocking( - new OnOffsetsFetchCallback(shuffleBlockIds, endpoint, listener, offsetMemory, - dataAddresses, dataRkeysCache)); - shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime); - } - - @Override - public void close() { - offsetRkeysCache.values().forEach(UcpRemoteKey::close); - dataRkeysCache.values().forEach(UcpRemoteKey::close); - logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime()); - } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/OnOffsetsFetchCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/OnOffsetsFetchCallback.java index d04e541c..86810d22 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/OnOffsetsFetchCallback.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/OnOffsetsFetchCallback.java @@ -22,49 +22,49 @@ * Callback, called when got all offsets for blocks */ public class OnOffsetsFetchCallback extends ReducerCallback { - private final RegisteredMemory offsetMemory; - private final long[] dataAddresses; - private Map dataRkeysCache; + private final RegisteredMemory offsetMemory; + private final long[] dataAddresses; + private Map dataRkeysCache; - public OnOffsetsFetchCallback(ShuffleBlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener, - RegisteredMemory offsetMemory, long[] dataAddresses, - Map dataRkeysCache) { - super(blockIds, endpoint, listener); - this.offsetMemory = offsetMemory; - this.dataAddresses = dataAddresses; - this.dataRkeysCache = dataRkeysCache; - } - - @Override - public void onSuccess(UcpRequest request) { - ByteBuffer resultOffset = offsetMemory.getBuffer(); - long totalSize = 0; - int[] sizes = new int[blockIds.length]; - int offsetSize = UnsafeUtils.LONG_SIZE; - for (int i = 0; i < blockIds.length; i++) { - // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd | - long blockOffset = resultOffset.getLong(i * 2 * offsetSize); - long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset; - assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE); - sizes[i] = (int) blockLength; - totalSize += blockLength; - dataAddresses[i] += blockOffset; + public OnOffsetsFetchCallback(ShuffleBlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener, + RegisteredMemory offsetMemory, long[] dataAddresses, + Map dataRkeysCache) { + super(blockIds, endpoint, listener); + this.offsetMemory = offsetMemory; + this.dataAddresses = dataAddresses; + this.dataRkeysCache = dataRkeysCache; } - assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE); - mempool.put(offsetMemory); - RegisteredMemory blocksMemory = mempool.get((int) totalSize); + @Override + public void onSuccess(UcpRequest request) { + ByteBuffer resultOffset = offsetMemory.getBuffer(); + long totalSize = 0; + int[] sizes = new int[blockIds.length]; + int offsetSize 
= UnsafeUtils.LONG_SIZE; + for (int i = 0; i < blockIds.length; i++) { + // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd | + long blockOffset = resultOffset.getLong(i * 2 * offsetSize); + long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset; + assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE); + sizes[i] = (int) blockLength; + totalSize += blockLength; + dataAddresses[i] += blockOffset; + } - long offset = 0; - // Submits N fetch blocks requests - for (int i = 0; i < blockIds.length; i++) { - endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(((ShuffleBlockId)blockIds[i]).mapId()), - UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]); - offset += sizes[i]; - } + assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE); + mempool.put(offsetMemory); + RegisteredMemory blocksMemory = mempool.get((int) totalSize); - // Process blocks when all fetched. - // Flush guarantees that callback would invoke when all fetch requests will completed. - endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes)); - } + long offset = 0; + // Submits N fetch blocks requests + for (int i = 0; i < blockIds.length; i++) { + endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(((ShuffleBlockId) blockIds[i]).mapId()), + UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]); + offset += sizes[i]; + } + + // Process blocks when all fetched. + // Flush guarantees that callback would invoke when all fetch requests will completed. + endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes)); + } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/UcxShuffleClient.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/UcxShuffleClient.java index e2c499c5..4036b639 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/UcxShuffleClient.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/UcxShuffleClient.java @@ -27,85 +27,84 @@ import java.util.HashMap; public class UcxShuffleClient extends ShuffleClient { - private final MemoryPool mempool; - private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class); - private final UcxShuffleManager ucxShuffleManager; - private final TempShuffleReadMetrics shuffleReadMetrics; - private final UcxWorkerWrapper workerWrapper; - final HashMap offsetRkeysCache = new HashMap<>(); - final HashMap dataRkeysCache = new HashMap<>(); - - public UcxShuffleClient(TempShuffleReadMetrics shuffleReadMetrics, - UcxWorkerWrapper workerWrapper) { - this.ucxShuffleManager = (UcxShuffleManager) SparkEnv.get().shuffleManager(); - this.mempool = ucxShuffleManager.ucxNode().getMemoryPool(); - this.shuffleReadMetrics = shuffleReadMetrics; - this.workerWrapper = workerWrapper; - } - - /** - * Submits n non blocking fetch offsets to get needed offsets for n blocks. 
+     */
+    private void submitFetchOffsets(UcpEndpoint endpoint, ShuffleBlockId[] blockIds,
+                                    long[] dataAddresses, RegisteredMemory offsetMemory) {
+        DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(blockIds[0].shuffleId());
+        for (int i = 0; i < blockIds.length; i++) {
+            ShuffleBlockId blockId = blockIds[i];
+
+            long offsetAddress = driverMetadata.offsetAddress(blockId.mapId());
+            dataAddresses[i] = driverMetadata.dataAddress(blockId.mapId());
+
+            offsetRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
+                endpoint.unpackRemoteKey(driverMetadata.offsetRkey(blockId.mapId())));
+
+            dataRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
+                endpoint.unpackRemoteKey(driverMetadata.dataRkey(blockId.mapId())));
+
+            endpoint.getNonBlockingImplicit(
+                offsetAddress + blockId.reduceId() * UnsafeUtils.LONG_SIZE,
+                offsetRkeysCache.get(blockId.mapId()),
+                UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UnsafeUtils.LONG_SIZE),
+                2L * UnsafeUtils.LONG_SIZE);
+        }
+    }
+
+    /**
+     * Reducer entry point. Fetches remote blocks using two ucp_get calls.
+     * This method is called inside ShuffleFetchIterator's loop over hosts.
+     * It first fetches the block offsets from the index file, and then the blocks themselves.
+ */ + @Override + public void fetchBlocks(String host, int port, String execId, + String[] blockIds, BlockFetchingListener listener, + DownloadFileManager downloadFileManager) { + long startTime = System.currentTimeMillis(); + + BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty()); + UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId); + + long[] dataAddresses = new long[blockIds.length]; + + // Need to fetch 2 long offsets current block + next block to calculate exact block size. + RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length); + + ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds) + .map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new); + + // Submits N implicit get requests without callback + submitFetchOffsets(endpoint, shuffleBlockIds, dataAddresses, offsetMemory); + + // flush guarantees that all that requests completes when callback is called. + endpoint.flushNonBlocking( + new OnOffsetsFetchCallback(shuffleBlockIds, endpoint, listener, offsetMemory, + dataAddresses, dataRkeysCache)); + shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime); + } + + @Override + public void close() { + offsetRkeysCache.values().forEach(UcpRemoteKey::close); + dataRkeysCache.values().forEach(UcpRemoteKey::close); + logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime()); } - } - - /** - * Reducer entry point. Fetches remote blocks, using 2 ucp_get calls. - * This method is inside ShuffleFetchIterator's for loop over hosts. - * First fetches block offset from index file, and then fetches block itself. - */ - @Override - public void fetchBlocks(String host, int port, String execId, - String[] blockIds, BlockFetchingListener listener, - DownloadFileManager downloadFileManager) { - long startTime = System.currentTimeMillis(); - - BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty()); - UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId); - - long[] dataAddresses = new long[blockIds.length]; - - // Need to fetch 2 long offsets current block + next block to calculate exact block size. - RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length); - - ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds) - .map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new); - - // Submits N implicit get requests without callback - submitFetchOffsets(endpoint, shuffleBlockIds, dataAddresses, offsetMemory); - - // flush guarantees that all that requests completes when callback is called. - // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush. 
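Note the behavioral change buried in this hunk: the old code flushed the whole worker (the removed TODO above tracked openucx/ucx#4267), while the new code flushes only the endpoint, tying the callback to that peer's outstanding requests. A hedged sketch of the endpoint-scoped barrier; flushThen is an illustrative wrapper, not PR code:

```
import org.openucx.jucx.UcxCallback;
import org.openucx.jucx.ucp.UcpEndpoint;
import org.openucx.jucx.ucp.UcpRequest;

final class EndpointFlushSketch {
    // The callback fires once all operations previously posted on *this*
    // endpoint complete, not everything outstanding on the worker.
    static void flushThen(UcpEndpoint endpoint, Runnable onComplete) {
        endpoint.flushNonBlocking(new UcxCallback() {
            @Override
            public void onSuccess(UcpRequest request) {
                onComplete.run();
            }
        });
    }
}
```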
- workerWrapper.worker().flushNonBlocking( - new OnOffsetsFetchCallback(shuffleBlockIds, endpoint, listener, offsetMemory, - dataAddresses, dataRkeysCache)); - shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime); - } - - @Override - public void close() { - offsetRkeysCache.values().forEach(UcpRemoteKey::close); - dataRkeysCache.values().forEach(UcpRemoteKey::close); - logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime()); - } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/OnOffsetsFetchCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/OnOffsetsFetchCallback.java index 372955b5..e792e6db 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/OnOffsetsFetchCallback.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/OnOffsetsFetchCallback.java @@ -25,69 +25,69 @@ * Callback, called when got all offsets for blocks */ public class OnOffsetsFetchCallback extends ReducerCallback { - private final RegisteredMemory offsetMemory; - private final long[] dataAddresses; - private Map<Integer, UcpRemoteKey> dataRkeysCache; - private final Map<Long, Integer> mapId2PartitionId; + private final RegisteredMemory offsetMemory; + private final long[] dataAddresses; + private Map<Integer, UcpRemoteKey> dataRkeysCache; + private final Map<Long, Integer> mapId2PartitionId; - public OnOffsetsFetchCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener, - RegisteredMemory offsetMemory, long[] dataAddresses, - Map<Integer, UcpRemoteKey> dataRkeysCache, - Map<Long, Integer> mapId2PartitionId) { - super(blockIds, endpoint, listener); - this.offsetMemory = offsetMemory; - this.dataAddresses = dataAddresses; - this.dataRkeysCache = dataRkeysCache; - this.mapId2PartitionId = mapId2PartitionId; - } + public OnOffsetsFetchCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener, + RegisteredMemory offsetMemory, long[] dataAddresses, + Map<Integer, UcpRemoteKey> dataRkeysCache, + Map<Long, Integer> mapId2PartitionId) { + super(blockIds, endpoint, listener); + this.offsetMemory = offsetMemory; + this.dataAddresses = dataAddresses; + this.dataRkeysCache = dataRkeysCache; + this.mapId2PartitionId = mapId2PartitionId; + } - @Override - public void onSuccess(UcpRequest request) { - ByteBuffer resultOffset = offsetMemory.getBuffer(); - long totalSize = 0; - int[] sizes = new int[blockIds.length]; - int offset = 0; - long blockOffset; - long blockLength; - int offsetSize = UnsafeUtils.LONG_SIZE; - for (int i = 0; i < blockIds.length; i++) { - // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd | - if (blockIds[i] instanceof ShuffleBlockBatchId) { - ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId) blockIds[i]; - int blocksInBatch = blockBatchId.endReduceId() - blockBatchId.startReduceId(); - blockOffset = resultOffset.getLong(offset * 2 * offsetSize); - blockLength = resultOffset.getLong(offset * 2 * offsetSize + offsetSize * blocksInBatch) - - blockOffset; - offset += blocksInBatch; - } else { - blockOffset = resultOffset.getLong(offset * 16); - blockLength = resultOffset.getLong(offset * 16 + 8) - blockOffset; - offset++; - } + @Override + public void onSuccess(UcpRequest request) { + ByteBuffer resultOffset = offsetMemory.getBuffer(); + long totalSize = 0; + int[] sizes = new int[blockIds.length]; + int offset = 0; + long blockOffset; + long blockLength; + int offsetSize = UnsafeUtils.LONG_SIZE; + for (int i = 0; i < blockIds.length; i++) { + // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd | + if (blockIds[i] instanceof ShuffleBlockBatchId) { + ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId) blockIds[i]; + int blocksInBatch = blockBatchId.endReduceId() - blockBatchId.startReduceId(); + blockOffset = resultOffset.getLong(offset * 2 * offsetSize); + blockLength = resultOffset.getLong(offset * 2 * offsetSize + offsetSize * blocksInBatch) + - blockOffset; + offset += blocksInBatch; + } else { + blockOffset = resultOffset.getLong(offset * 16); + blockLength = resultOffset.getLong(offset * 16 + 8) - blockOffset; + offset++; + } - assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE); - sizes[i] = (int) blockLength; - totalSize += blockLength; - dataAddresses[i] += blockOffset; - } + assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE); + sizes[i] = (int) blockLength; + totalSize += blockLength; + dataAddresses[i] += blockOffset; + } - assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE); - mempool.put(offsetMemory); - RegisteredMemory blocksMemory = mempool.get((int) totalSize); + assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE); + mempool.put(offsetMemory); + RegisteredMemory blocksMemory = mempool.get((int) totalSize); - offset = 0; - // Submits N fetch blocks requests - for (int i = 0; i < blockIds.length; i++) { - int mapPartitionId = (blockIds[i] instanceof ShuffleBlockId) ? - mapId2PartitionId.get(((ShuffleBlockId)blockIds[i]).mapId()) : - mapId2PartitionId.get(((ShuffleBlockBatchId)blockIds[i]).mapId()); - endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(mapPartitionId), - UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]); - offset += sizes[i]; - } + offset = 0; + // Submits N fetch blocks requests + for (int i = 0; i < blockIds.length; i++) { + int mapPartitionId = (blockIds[i] instanceof ShuffleBlockId) ? + mapId2PartitionId.get(((ShuffleBlockId) blockIds[i]).mapId()) : + mapId2PartitionId.get(((ShuffleBlockBatchId) blockIds[i]).mapId()); + endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(mapPartitionId), + UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]); + offset += sizes[i]; + } - // Process blocks when all fetched. - // Flush guarantees that callback would invoke when all fetch requests will completed. - endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes)); - } + // Process blocks once they are all fetched. + // The flush guarantees that the callback is invoked once all fetch requests have completed. 
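/*
 * [Editor's note, not part of the patch] Each entry fetched into the offset buffer
 * is a window of consecutive index-file offsets: 2 longs for a plain ShuffleBlockId,
 * 2 * k longs for a ShuffleBlockBatchId spanning k reduce partitions. The code above
 * reads the first long of the window (block start) and the long k slots later (block
 * end); their difference is the block size consumed by the flush below. A standalone
 * sketch of that arithmetic (blockLengthOf is an illustrative name; entryIndex is
 * the running counter called "offset" above):
 */
static long blockLengthOf(java.nio.ByteBuffer offsets, int entryIndex, int blocksInBatch) {
  final int LONG_SIZE = 8;                                      // == UnsafeUtils.LONG_SIZE
  int base = entryIndex * 2 * LONG_SIZE;                        // entries are 2-long strided
  long start = offsets.getLong(base);                           // first offset in the window
  long end = offsets.getLong(base + LONG_SIZE * blocksInBatch); // offset k slots later
  return end - start;                                           // exact byte size of the block(s)
}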
+ endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes)); + } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/UcxShuffleClient.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/UcxShuffleClient.java index 77a3d87c..c7d1360e 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/UcxShuffleClient.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/UcxShuffleClient.java @@ -27,110 +27,109 @@ import java.util.Map; public class UcxShuffleClient extends BlockStoreClient { - private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class); - private final UcxWorkerWrapper workerWrapper; - private final Map mapId2PartitionId; - private final TempShuffleReadMetrics shuffleReadMetrics; - private final int shuffleId; - final HashMap offsetRkeysCache = new HashMap<>(); - final HashMap dataRkeysCache = new HashMap<>(); - - - public UcxShuffleClient(int shuffleId, UcxWorkerWrapper workerWrapper, - Map mapId2PartitionId, TempShuffleReadMetrics shuffleReadMetrics) { - this.workerWrapper = workerWrapper; - this.shuffleId = shuffleId; - this.mapId2PartitionId = mapId2PartitionId; - this.shuffleReadMetrics = shuffleReadMetrics; - } - - /** - * Submits n non blocking fetch offsets to get needed offsets for n blocks. - */ - private void submitFetchOffsets(UcpEndpoint endpoint, BlockId[] blockIds, - RegisteredMemory offsetMemory, - long[] dataAddresses) { - DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(shuffleId); - long offset = 0; - int startReduceId; - long size; - - for (int i = 0; i < blockIds.length; i++) { - BlockId blockId = blockIds[i]; - int mapIdpartition; - - if (blockId instanceof ShuffleBlockId) { - ShuffleBlockId shuffleBlockId = (ShuffleBlockId) blockId; - mapIdpartition = mapId2PartitionId.get(shuffleBlockId.mapId()); - size = 2L * UnsafeUtils.LONG_SIZE; - startReduceId = shuffleBlockId.reduceId(); - } else { - ShuffleBlockBatchId shuffleBlockBatchId = (ShuffleBlockBatchId) blockId; - mapIdpartition = mapId2PartitionId.get(shuffleBlockBatchId.mapId()); - size = (shuffleBlockBatchId.endReduceId() - shuffleBlockBatchId.startReduceId()) - * 2L * UnsafeUtils.LONG_SIZE; - startReduceId = shuffleBlockBatchId.startReduceId(); - } - - long offsetAddress = driverMetadata.offsetAddress(mapIdpartition); - dataAddresses[i] = driverMetadata.dataAddress(mapIdpartition); - - offsetRkeysCache.computeIfAbsent(mapIdpartition, mapId -> - endpoint.unpackRemoteKey(driverMetadata.offsetRkey(mapIdpartition))); - - dataRkeysCache.computeIfAbsent(mapIdpartition, mapId -> - endpoint.unpackRemoteKey(driverMetadata.dataRkey(mapIdpartition))); - - endpoint.getNonBlockingImplicit( - offsetAddress + startReduceId * UnsafeUtils.LONG_SIZE, - offsetRkeysCache.get(mapIdpartition), - UcxUtils.getAddress(offsetMemory.getBuffer()) + offset, - size); - - offset += size; + private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class); + private final UcxWorkerWrapper workerWrapper; + private final Map mapId2PartitionId; + private final TempShuffleReadMetrics shuffleReadMetrics; + private final int shuffleId; + final HashMap offsetRkeysCache = new HashMap<>(); + final HashMap dataRkeysCache = new HashMap<>(); + + + public UcxShuffleClient(int shuffleId, UcxWorkerWrapper workerWrapper, + Map mapId2PartitionId, TempShuffleReadMetrics shuffleReadMetrics) { + this.workerWrapper = workerWrapper; + this.shuffleId = shuffleId; + 
this.mapId2PartitionId = mapId2PartitionId; + this.shuffleReadMetrics = shuffleReadMetrics; } - } - - @Override - public void fetchBlocks(String host, int port, String execId, String[] blockIds, BlockFetchingListener listener, - DownloadFileManager downloadFileManager) { - long startTime = System.currentTimeMillis(); - BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty()); - UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId); - long[] dataAddresses = new long[blockIds.length]; - int totalBlocks = 0; - - BlockId[] blocks = new BlockId[blockIds.length]; - - for (int i = 0; i < blockIds.length; i++) { - blocks[i] = BlockId.apply(blockIds[i]); - if (blocks[i] instanceof ShuffleBlockId) { - totalBlocks += 1; - } else { - ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId)blocks[i]; - totalBlocks += (blockBatchId.endReduceId() - blockBatchId.startReduceId()); - } + + /** + * Submits n non blocking fetch offsets to get needed offsets for n blocks. + */ + private void submitFetchOffsets(UcpEndpoint endpoint, BlockId[] blockIds, + RegisteredMemory offsetMemory, + long[] dataAddresses) { + DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(shuffleId); + long offset = 0; + int startReduceId; + long size; + + for (int i = 0; i < blockIds.length; i++) { + BlockId blockId = blockIds[i]; + int mapIdpartition; + + if (blockId instanceof ShuffleBlockId) { + ShuffleBlockId shuffleBlockId = (ShuffleBlockId) blockId; + mapIdpartition = mapId2PartitionId.get(shuffleBlockId.mapId()); + size = 2L * UnsafeUtils.LONG_SIZE; + startReduceId = shuffleBlockId.reduceId(); + } else { + ShuffleBlockBatchId shuffleBlockBatchId = (ShuffleBlockBatchId) blockId; + mapIdpartition = mapId2PartitionId.get(shuffleBlockBatchId.mapId()); + size = (shuffleBlockBatchId.endReduceId() - shuffleBlockBatchId.startReduceId()) + * 2L * UnsafeUtils.LONG_SIZE; + startReduceId = shuffleBlockBatchId.startReduceId(); + } + + long offsetAddress = driverMetadata.offsetAddress(mapIdpartition); + dataAddresses[i] = driverMetadata.dataAddress(mapIdpartition); + + offsetRkeysCache.computeIfAbsent(mapIdpartition, mapId -> + endpoint.unpackRemoteKey(driverMetadata.offsetRkey(mapIdpartition))); + + dataRkeysCache.computeIfAbsent(mapIdpartition, mapId -> + endpoint.unpackRemoteKey(driverMetadata.dataRkey(mapIdpartition))); + + endpoint.getNonBlockingImplicit( + offsetAddress + startReduceId * UnsafeUtils.LONG_SIZE, + offsetRkeysCache.get(mapIdpartition), + UcxUtils.getAddress(offsetMemory.getBuffer()) + offset, + size); + + offset += size; + } + } + + @Override + public void fetchBlocks(String host, int port, String execId, String[] blockIds, BlockFetchingListener listener, + DownloadFileManager downloadFileManager) { + long startTime = System.currentTimeMillis(); + BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty()); + UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId); + long[] dataAddresses = new long[blockIds.length]; + int totalBlocks = 0; + + BlockId[] blocks = new BlockId[blockIds.length]; + + for (int i = 0; i < blockIds.length; i++) { + blocks[i] = BlockId.apply(blockIds[i]); + if (blocks[i] instanceof ShuffleBlockId) { + totalBlocks += 1; + } else { + ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId) blocks[i]; + totalBlocks += (blockBatchId.endReduceId() - blockBatchId.startReduceId()); + } + } + + RegisteredMemory offsetMemory = ((UcxShuffleManager) SparkEnv.get().shuffleManager()) + 
.ucxNode().getMemoryPool().get(totalBlocks * 2 * UnsafeUtils.LONG_SIZE); + // Submits N implicit get requests without callback + submitFetchOffsets(endpoint, blocks, offsetMemory, dataAddresses); + + // The flush guarantees that all of those requests have completed when the callback is invoked. + endpoint.flushNonBlocking( + new OnOffsetsFetchCallback(blocks, endpoint, listener, offsetMemory, + dataAddresses, dataRkeysCache, mapId2PartitionId)); + + shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime); } - RegisteredMemory offsetMemory = ((UcxShuffleManager)SparkEnv.get().shuffleManager()) - .ucxNode().getMemoryPool().get(totalBlocks * 2 * UnsafeUtils.LONG_SIZE); - // Submits N implicit get requests without callback - submitFetchOffsets(endpoint, blocks, offsetMemory, dataAddresses); - - // flush guarantees that all that requests completes when callback is called. - // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush. - workerWrapper.worker().flushNonBlocking( - new OnOffsetsFetchCallback(blocks, endpoint, listener, offsetMemory, - dataAddresses, dataRkeysCache, mapId2PartitionId)); - - shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime); - } - - @Override - public void close() { - offsetRkeysCache.values().forEach(UcpRemoteKey::close); - dataRkeysCache.values().forEach(UcpRemoteKey::close); - logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime()); - } + @Override + public void close() { + offsetRkeysCache.values().forEach(UcpRemoteKey::close); + dataRkeysCache.values().forEach(UcpRemoteKey::close); + logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime()); + } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/rpc/RpcConnectionCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/rpc/RpcConnectionCallback.java index a81570e3..b0d885f2 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/rpc/RpcConnectionCallback.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/rpc/RpcConnectionCallback.java @@ -28,72 +28,72 @@ * introduce cluster to connected executor. 
*/ public class RpcConnectionCallback extends UcxCallback { - private static final Logger logger = LoggerFactory.getLogger(RpcConnectionCallback.class); - private final ByteBuffer metadataBuffer; - private final boolean isDriver; - private final UcxNode ucxNode; - private static final ConcurrentMap<UcpEndpoint, ByteBuffer> rpcConnections = - UcxNode.getRpcConnections(); - private static final ConcurrentMap<BlockManagerId, ByteBuffer> workerAdresses = - UcxNode.getWorkerAddresses(); + private static final Logger logger = LoggerFactory.getLogger(RpcConnectionCallback.class); + private final ByteBuffer metadataBuffer; + private final boolean isDriver; + private final UcxNode ucxNode; + private static final ConcurrentMap<UcpEndpoint, ByteBuffer> rpcConnections = + UcxNode.getRpcConnections(); + private static final ConcurrentMap<BlockManagerId, ByteBuffer> workerAdresses = + UcxNode.getWorkerAddresses(); - RpcConnectionCallback(ByteBuffer metadataBuffer, boolean isDriver, UcxNode ucxNode) { - this.metadataBuffer = metadataBuffer; - this.isDriver = isDriver; - this.ucxNode = ucxNode; - } + RpcConnectionCallback(ByteBuffer metadataBuffer, boolean isDriver, UcxNode ucxNode) { + this.metadataBuffer = metadataBuffer; + this.isDriver = isDriver; + this.ucxNode = ucxNode; + } - @Override - public void onSuccess(UcpRequest request) { - int workerAddressSize = metadataBuffer.getInt(); - ByteBuffer workerAddress = Platform.allocateDirectBuffer(workerAddressSize); + @Override + public void onSuccess(UcpRequest request) { + int workerAddressSize = metadataBuffer.getInt(); + ByteBuffer workerAddress = Platform.allocateDirectBuffer(workerAddressSize); - // Copy worker address from metadata buffer to separate buffer. 
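/*
 * [Editor's note, not part of the patch] The metadata buffer parsed here carries one
 * executor-hello message laid out as
 * | workerAddressSize : int | workerAddress bytes | serialized BlockManagerId |.
 * A hedged sketch of how the sending side could assemble such a message with the
 * helpers visible in this patch (buildHelloMessage is an illustrative name; the
 * buffer size would come from spark.shuffle.ucx.rpc.metadata.bufferSize):
 */
static ByteBuffer buildHelloMessage(ByteBuffer workerAddress, BlockManagerId blockManagerId,
                                    int rpcBufferSize) throws IOException {
  ByteBuffer msg = Platform.allocateDirectBuffer(rpcBufferSize);
  msg.putInt(workerAddress.remaining());      // 1. length prefix for the address
  msg.put(workerAddress.duplicate());         // 2. raw UCP worker address bytes
  SerializableBlockManagerID
      .serializeBlockManagerID(blockManagerId, msg); // 3. executor identity
  msg.clear();                                // rewind so it can be sent as-is
  return msg;
}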
+ final ByteBuffer metadataView = metadataBuffer.duplicate(); + metadataView.limit(metadataView.position() + workerAddressSize); + workerAddress.put(metadataView); + metadataBuffer.position(metadataBuffer.position() + workerAddressSize); - BlockManagerId blockManagerId; - try { - blockManagerId = SerializableBlockManagerID - .deserializeBlockManagerID(metadataBuffer); - } catch (IOException e) { - String errorMsg = String.format("Failed to deserialize BlockManagerId: %s", e.getMessage()); - throw new UcxException(errorMsg); - } - logger.debug("Received RPC message from {}", blockManagerId); - UcpWorker globalWorker = ucxNode.getGlobalWorker(); + BlockManagerId blockManagerId; + try { + blockManagerId = SerializableBlockManagerID + .deserializeBlockManagerID(metadataBuffer); + } catch (IOException e) { + String errorMsg = String.format("Failed to deserialize BlockManagerId: %s", e.getMessage()); + throw new UcxException(errorMsg); + } + logger.debug("Received RPC message from {}", blockManagerId); + UcpWorker globalWorker = ucxNode.getGlobalWorker(); - workerAddress.clear(); + workerAddress.clear(); - if (isDriver) { - metadataBuffer.clear(); - UcpEndpoint newConnection = globalWorker.newEndpoint( - new UcpEndpointParams().setPeerErrorHandlingMode() - .setUcpAddress(workerAddress)); - // For each existing connection - rpcConnections.forEach((connection, connectionMetadata) -> { - // send address of joined worker to already connected workers - connection.sendTaggedNonBlocking(metadataBuffer, null); - // introduce other workers to joined worker - newConnection.sendTaggedNonBlocking(connectionMetadata, null); - }); + if (isDriver) { + metadataBuffer.clear(); + UcpEndpoint newConnection = globalWorker.newEndpoint( + new UcpEndpointParams().setPeerErrorHandlingMode() + .setUcpAddress(workerAddress)); + // For each existing connection + rpcConnections.forEach((connection, connectionMetadata) -> { + // send address of joined worker to already connected workers + connection.sendTaggedNonBlocking(metadataBuffer, null); + // introduce other workers to joined worker + newConnection.sendTaggedNonBlocking(connectionMetadata, null); + }); - rpcConnections.put(newConnection, metadataBuffer); - } - workerAdresses.put(blockManagerId, workerAddress); - synchronized (workerAdresses) { - workerAdresses.notifyAll(); + rpcConnections.put(newConnection, metadataBuffer); + } + workerAdresses.put(blockManagerId, workerAddress); + synchronized (workerAdresses) { + workerAdresses.notifyAll(); + } } - } - @Override - public void onError(int ucsStatus, String errorMsg) { - // UCS_ERR_CANCELED = -16, - if (ucsStatus != -16) { - logger.error("Request error: {}", errorMsg); - throw new UcxException(errorMsg); + @Override + public void onError(int ucsStatus, String errorMsg) { + // UCS_ERR_CANCELED = -16, + if (ucsStatus != -16) { + logger.error("Request error: {}", errorMsg); + throw new UcxException(errorMsg); + } } - } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/rpc/SerializableBlockManagerID.java b/src/main/java/org/apache/spark/shuffle/ucx/rpc/SerializableBlockManagerID.java index dfc5af7a..5560e418 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/rpc/SerializableBlockManagerID.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/rpc/SerializableBlockManagerID.java @@ -14,19 +14,19 @@ */ public class SerializableBlockManagerID { - public static void serializeBlockManagerID(BlockManagerId blockManagerId, - ByteBuffer metadataBuffer) throws IOException { - ObjectOutputStream oos = new 
ObjectOutputStream( - new ByteBufferOutputStream(metadataBuffer)); - blockManagerId.writeExternal(oos); - oos.close(); - } + public static void serializeBlockManagerID(BlockManagerId blockManagerId, + ByteBuffer metadataBuffer) throws IOException { + ObjectOutputStream oos = new ObjectOutputStream( + new ByteBufferOutputStream(metadataBuffer)); + blockManagerId.writeExternal(oos); + oos.close(); + } - static BlockManagerId deserializeBlockManagerID(ByteBuffer metadataBuffer) throws IOException { - ObjectInputStream ois = - new ObjectInputStream(new ByteBufferInputStream(metadataBuffer)); - BlockManagerId blockManagerId = BlockManagerId.apply(ois); - ois.close(); - return blockManagerId; - } + static BlockManagerId deserializeBlockManagerID(ByteBuffer metadataBuffer) throws IOException { + ObjectInputStream ois = + new ObjectInputStream(new ByteBufferInputStream(metadataBuffer)); + BlockManagerId blockManagerId = BlockManagerId.apply(ois); + ois.close(); + return blockManagerId; + } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxListenerThread.java b/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxListenerThread.java index ce2ecff5..e2252d06 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxListenerThread.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxListenerThread.java @@ -17,47 +17,47 @@ * Thread for progressing global worker for connection establishment and RPC exchange. */ public class UcxListenerThread extends Thread implements Runnable { - private static final Logger logger = LoggerFactory.getLogger(UcxListenerThread.class); - private final UcxNode ucxNode; - private final boolean isDriver; - private final UcpWorker globalWorker; + private static final Logger logger = LoggerFactory.getLogger(UcxListenerThread.class); + private final UcxNode ucxNode; + private final boolean isDriver; + private final UcpWorker globalWorker; - public UcxListenerThread(UcxNode ucxNode, boolean isDriver) { - this.ucxNode = ucxNode; - this.isDriver = isDriver; - this.globalWorker = ucxNode.getGlobalWorker(); - setDaemon(true); - setName("UcxListenerThread"); - } + public UcxListenerThread(UcxNode ucxNode, boolean isDriver) { + this.ucxNode = ucxNode; + this.isDriver = isDriver; + this.globalWorker = ucxNode.getGlobalWorker(); + setDaemon(true); + setName("UcxListenerThread"); + } - /** - * 2. Both Driver and Executor. Accept Recv request. - * If on driver broadcast it to other executors. On executor just save worker addresses. - */ - private UcpRequest recvRequest() { - ByteBuffer metadataBuffer = Platform.allocateDirectBuffer( - ucxNode.getConf().metadataRPCBufferSize()); - RpcConnectionCallback callback = new RpcConnectionCallback(metadataBuffer, isDriver, ucxNode); - return globalWorker.recvTaggedNonBlocking(metadataBuffer, callback); - } + /** + * 2. Both Driver and Executor. Accept Recv request. + * If on driver broadcast it to other executors. On executor just save worker addresses. + */ + private UcpRequest recvRequest() { + ByteBuffer metadataBuffer = Platform.allocateDirectBuffer( + ucxNode.getConf().metadataRPCBufferSize()); + RpcConnectionCallback callback = new RpcConnectionCallback(metadataBuffer, isDriver, ucxNode); + return globalWorker.recvTaggedNonBlocking(metadataBuffer, callback); + } - @Override - public void run() { - UcpRequest recv = recvRequest(); - while (!isInterrupted()) { - if (recv.isCompleted()) { - // Process 1 recv request at a time. 
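/*
 * [Editor's note, not part of the patch] The run() loop here keeps exactly one
 * outstanding tagged receive: when it completes, the next one is posted before
 * progressing again, and the thread parks in waitForEvents() whenever progress()
 * reports that nothing moved (legal only because the context was created with the
 * UCP wakeup feature). The same event-driven pattern, reduced to a self-contained
 * sketch (progressUntilCompleted is an illustrative name):
 */
static void progressUntilCompleted(UcpWorker worker, UcpRequest request) throws Exception {
  while (!request.isCompleted()) {
    if (worker.progress() == 0) { // no work was done this pass ...
      worker.waitForEvents();     // ... so block until the transport wakes us up
    }
  }
}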
- recv = recvRequest(); - } - try { - if (globalWorker.progress() == 0) { - globalWorker.waitForEvents(); + @Override + public void run() { + UcpRequest recv = recvRequest(); + while (!isInterrupted()) { + if (recv.isCompleted()) { + // Process 1 recv request at a time. + recv = recvRequest(); + } + try { + if (globalWorker.progress() == 0) { + globalWorker.waitForEvents(); + } + } catch (Exception e) { + logger.error(e.getLocalizedMessage()); + interrupt(); + } } - } catch (Exception e) { - logger.error(e.getLocalizedMessage()); - interrupt(); - } + globalWorker.cancelRequest(recv); } - globalWorker.cancelRequest(recv); - } } diff --git a/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxRemoteMemory.java b/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxRemoteMemory.java index 08626cee..e48b26f2 100755 --- a/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxRemoteMemory.java +++ b/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxRemoteMemory.java @@ -16,39 +16,40 @@ * spark's mechanism to broadcast tasks. */ public class UcxRemoteMemory implements Serializable { - private long address; - private ByteBuffer rkeyBuffer; - - public UcxRemoteMemory(long address, ByteBuffer rkeyBuffer) { - this.address = address; - this.rkeyBuffer = rkeyBuffer; - } - - public UcxRemoteMemory() {} - - private void writeObject(ObjectOutputStream out) throws IOException { - out.writeLong(address); - out.writeInt(rkeyBuffer.limit()); - byte[] copy = new byte[rkeyBuffer.limit()]; - rkeyBuffer.clear(); - rkeyBuffer.get(copy); - out.write(copy); - } - - private void readObject(ObjectInputStream in) throws IOException { - this.address = in.readLong(); - int bufferSize = in.readInt(); - byte[] buffer = new byte[bufferSize]; - in.read(buffer, 0, bufferSize); - this.rkeyBuffer = ByteBuffer.allocateDirect(bufferSize).put(buffer); - this.rkeyBuffer.clear(); - } - - public long getAddress() { - return address; - } - - public ByteBuffer getRkeyBuffer() { - return rkeyBuffer; - } + private long address; + private ByteBuffer rkeyBuffer; + + public UcxRemoteMemory(long address, ByteBuffer rkeyBuffer) { + this.address = address; + this.rkeyBuffer = rkeyBuffer; + } + + public UcxRemoteMemory() { + } + + private void writeObject(ObjectOutputStream out) throws IOException { + out.writeLong(address); + out.writeInt(rkeyBuffer.limit()); + byte[] copy = new byte[rkeyBuffer.limit()]; + rkeyBuffer.clear(); + rkeyBuffer.get(copy); + out.write(copy); + } + + private void readObject(ObjectInputStream in) throws IOException { + this.address = in.readLong(); + int bufferSize = in.readInt(); + byte[] buffer = new byte[bufferSize]; + in.read(buffer, 0, bufferSize); + this.rkeyBuffer = ByteBuffer.allocateDirect(bufferSize).put(buffer); + this.rkeyBuffer.clear(); + } + + public long getAddress() { + return address; + } + + public ByteBuffer getRkeyBuffer() { + return rkeyBuffer; + } } diff --git a/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleBlockResolver.scala index 1405fade..e4dd335f 100755 --- a/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleBlockResolver.scala +++ b/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleBlockResolver.scala @@ -1,7 +1,7 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. + * See file LICENSE for terms. 
+ */ package org.apache.spark.shuffle import java.io.{File, RandomAccessFile} @@ -14,25 +14,31 @@ import org.openucx.jucx.ucp.{UcpMemMapParams, UcpMemory} import org.apache.spark.shuffle.ucx.UnsafeUtils import org.apache.spark.SparkException -/** - * Mapper entry point for UcxShuffle plugin. Performs memory registration - * of data and index files and publish addresses to driver metadata buffer. - */ -abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) - extends IndexShuffleBlockResolver(ucxShuffleManager.conf) { +/** Mapper entry point for UcxShuffle plugin. Performs memory registration + * of data and index files and publish addresses to driver metadata buffer. + */ +abstract class CommonUcxShuffleBlockResolver( + ucxShuffleManager: CommonUcxShuffleManager +) extends IndexShuffleBlockResolver(ucxShuffleManager.conf) { private lazy val memPool = ucxShuffleManager.ucxNode.getMemoryPool // Keep track of registered memory regions to release them when shuffle not needed - private val fileMappings = new ConcurrentHashMap[ShuffleId, CopyOnWriteArrayList[UcpMemory]].asScala - private val offsetMappings = new ConcurrentHashMap[ShuffleId, CopyOnWriteArrayList[UcpMemory]].asScala - - /** - * Mapper commit protocol extension. Register index and data files and publish all needed - * metadata to driver. - */ - def writeIndexFileAndCommitCommon(shuffleId: ShuffleId, mapId: Int, - lengths: Array[Long], dataTmp: File, - indexBackFile: RandomAccessFile, dataBackFile: RandomAccessFile): Unit = { + private val fileMappings = + new ConcurrentHashMap[ShuffleId, CopyOnWriteArrayList[UcpMemory]].asScala + private val offsetMappings = + new ConcurrentHashMap[ShuffleId, CopyOnWriteArrayList[UcpMemory]].asScala + + /** Mapper commit protocol extension. Register index and data files and publish all needed + * metadata to driver. + */ + def writeIndexFileAndCommitCommon( + shuffleId: ShuffleId, + mapId: Int, + lengths: Array[Long], + dataTmp: File, + indexBackFile: RandomAccessFile, + dataBackFile: RandomAccessFile + ): Unit = { val startTime = System.currentTimeMillis() fileMappings.putIfAbsent(shuffleId, new CopyOnWriteArrayList[UcpMemory]()) @@ -42,19 +48,26 @@ abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffle val dataFileChannel = dataBackFile.getChannel // Memory map and register data and index file. 
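/*
 * [Editor's note, not part of the patch] Each mapper output file below is mmap-ed and
 * then registered with UCX so reducers can later read it with one-sided ucp_get;
 * spark.shuffle.ucx.memory.useOdp switches to non-blocking (on-demand paging)
 * registration instead of pinning all pages up front. A hedged Scala sketch of that
 * register step (registerFile is an illustrative name; UnsafeUtils and the JUCX types
 * are the ones used in this file, java.nio.channels.FileChannel is assumed imported):
 */
def registerFile(context: UcpContext, channel: FileChannel,
                 length: Long, useOdp: Boolean): UcpMemory = {
  val address = UnsafeUtils.mmap(channel, 0, length) // map the file into memory
  val params = new UcpMemMapParams().setAddress(address).setLength(length)
  if (useOdp) {
    params.nonBlocking() // ODP: register without eagerly pinning the pages
  }
  context.memoryMap(params) // returns a UcpMemory whose rkey can be published
}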
- val dataAddress = UnsafeUtils.mmap(dataFileChannel, 0, dataBackFile.length()) - val memMapParams = new UcpMemMapParams().setAddress(dataAddress) + val dataAddress = + UnsafeUtils.mmap(dataFileChannel, 0, dataBackFile.length()) + val memMapParams = new UcpMemMapParams() + .setAddress(dataAddress) .setLength(dataBackFile.length()) if (ucxShuffleManager.ucxShuffleConf.useOdp) { memMapParams.nonBlocking() } - val dataMemory = ucxShuffleManager.ucxNode.getContext.memoryMap(memMapParams) + val dataMemory = + ucxShuffleManager.ucxNode.getContext.memoryMap(memMapParams) fileMappings(shuffleId).add(dataMemory) - assume(indexBackFile.length() == UnsafeUtils.LONG_SIZE * (lengths.length + 1)) + assume( + indexBackFile.length() == UnsafeUtils.LONG_SIZE * (lengths.length + 1) + ) - val offsetAddress = UnsafeUtils.mmap(indexFileChannel, 0, indexBackFile.length()) + val offsetAddress = + UnsafeUtils.mmap(indexFileChannel, 0, indexBackFile.length()) memMapParams.setAddress(offsetAddress).setLength(indexBackFile.length()) - val offsetMemory = ucxShuffleManager.ucxNode.getContext.memoryMap(memMapParams) + val offsetMemory = + ucxShuffleManager.ucxNode.getContext.memoryMap(memMapParams) offsetMappings(shuffleId).add(offsetMemory) dataFileChannel.close() @@ -65,14 +78,19 @@ abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffle val fileMemoryRkey = dataMemory.getRemoteKeyBuffer val offsetRkey = offsetMemory.getRemoteKeyBuffer - val metadataRegisteredMemory = memPool.get( - fileMemoryRkey.capacity() + offsetRkey.capacity() + 24) + val metadataRegisteredMemory = + memPool.get(fileMemoryRkey.capacity() + offsetRkey.capacity() + 24) val metadataBuffer = metadataRegisteredMemory.getBuffer.slice() - if (metadataBuffer.remaining() > ucxShuffleManager.ucxShuffleConf.metadataBlockSize) { - throw new SparkException(s"Metadata block size ${metadataBuffer.remaining() / 2} " + - s"is greater then configured ${ucxShuffleManager.ucxShuffleConf.RKEY_SIZE.key}" + - s"(${ucxShuffleManager.ucxShuffleConf.metadataBlockSize}).") + if ( + metadataBuffer + .remaining() > ucxShuffleManager.ucxShuffleConf.metadataBlockSize + ) { + throw new SparkException( + s"Metadata block size ${metadataBuffer.remaining() / 2} " + + s"is greater than configured ${ucxShuffleManager.ucxShuffleConf.RKEY_SIZE.key} " + + s"(${ucxShuffleManager.ucxShuffleConf.metadataBlockSize})." + ) } metadataBuffer.clear() @@ -94,16 +112,23 @@ abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffle mapId * ucxShuffleManager.ucxShuffleConf.metadataBlockSize val driverEndpoint = workerWrapper.driverEndpoint - val request = driverEndpoint.putNonBlocking(UcxUtils.getAddress(metadataBuffer), - metadataBuffer.remaining(), driverOffset, driverMetadata.driverRkey, null) + val request = driverEndpoint.putNonBlocking( + UcxUtils.getAddress(metadataBuffer), + metadataBuffer.remaining(), + driverOffset, + driverMetadata.driverRkey, + null + ) workerWrapper.preconnect() // Blocking progress needed to make sure last mapper published data to driver before // reducer starts. 
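/*
 * [Editor's note, not part of the patch] The driver metadata buffer is a dense array
 * with one fixed-size block per mapper (metadataBlockSize = 2 * rkeySize: the
 * index-file rkey plus the data-file rkey), so the PUT above targets
 * driverBase + mapId * metadataBlockSize. A hedged sketch of that publishing step
 * (publishMapOutput is an illustrative name; java.nio.ByteBuffer assumed imported):
 */
def publishMapOutput(driverEndpoint: UcpEndpoint, metadataBuffer: ByteBuffer,
                     driverBase: Long, driverRkey: UcpRemoteKey,
                     mapId: Int, metadataBlockSize: Long): UcpRequest = {
  // Mapper i owns the slice [base + i * blockSize, base + (i + 1) * blockSize).
  val driverOffset = driverBase + mapId * metadataBlockSize
  driverEndpoint.putNonBlocking(UcxUtils.getAddress(metadataBuffer),
    metadataBuffer.remaining(), driverOffset, driverRkey, null)
}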
workerWrapper.waitRequest(request) memPool.put(metadataRegisteredMemory) - logInfo(s"MapID: $mapId register files + publishing overhead: " + - s"${System.currentTimeMillis() - startTime} ms") + logInfo( + s"MapID: $mapId register files + publishing overhead: " + + s"${System.currentTimeMillis() - startTime} ms" + ) } private def unregisterAndUnmap(mem: UcpMemory): Unit = { @@ -114,10 +139,16 @@ abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffle } def removeShuffle(shuffleId: Int): Unit = { - fileMappings.remove(shuffleId).foreach((mappings: CopyOnWriteArrayList[UcpMemory]) => - mappings.asScala.par.foreach(unregisterAndUnmap)) - offsetMappings.remove(shuffleId).foreach((mappings: CopyOnWriteArrayList[UcpMemory]) => - mappings.asScala.par.foreach(unregisterAndUnmap)) + fileMappings + .remove(shuffleId) + .foreach((mappings: CopyOnWriteArrayList[UcpMemory]) => + mappings.asScala.par.foreach(unregisterAndUnmap) + ) + offsetMappings + .remove(shuffleId) + .foreach((mappings: CopyOnWriteArrayList[UcpMemory]) => + mappings.asScala.par.foreach(unregisterAndUnmap) + ) } override def stop(): Unit = { diff --git a/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleManager.scala index 05ebad4c..e2a00a77 100755 --- a/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleManager.scala +++ b/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleManager.scala @@ -1,7 +1,7 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ package org.apache.spark.shuffle import java.util.concurrent.ConcurrentHashMap @@ -16,10 +16,10 @@ import org.apache.spark.shuffle.ucx.UcxNode import org.apache.spark.shuffle.ucx.rpc.UcxRemoteMemory import org.apache.spark.unsafe.Platform -/** - * Common part for all spark versions for UcxShuffleManager logic - */ -abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) extends SortShuffleManager(conf) { +/** Common part for all spark versions for UcxShuffleManager logic + */ +abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) + extends SortShuffleManager(conf) { type ShuffleId = Int type MapId = Int val ucxShuffleConf = new UcxShuffleConf(conf) @@ -36,9 +36,11 @@ abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) e startUcxNodeIfMissing() } - protected def registerShuffleCommon[K, V, C](baseHandle: BaseShuffleHandle[K,V,C], - shuffleId: ShuffleId, - numMaps: Int): ShuffleHandle = { + protected def registerShuffleCommon[K, V, C]( + baseHandle: BaseShuffleHandle[K, V, C], + shuffleId: ShuffleId, + numMaps: Int + ): ShuffleHandle = { // Register metadata buffer where each map will publish it's index and data file metadata val metadataBufferSize = numMaps * ucxShuffleConf.metadataBlockSize val metadataBuffer = Platform.allocateDirectBuffer(metadataBufferSize.toInt) @@ -46,24 +48,25 @@ abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) e val metadataMemory = ucxNode.getContext.registerMemory(metadataBuffer) shuffleIdToMetadataBuffer.put(shuffleId, metadataMemory) - val driverMemory = new UcxRemoteMemory(metadataMemory.getAddress, - metadataMemory.getRemoteKeyBuffer) + val driverMemory = new UcxRemoteMemory( + metadataMemory.getAddress, + metadataMemory.getRemoteKeyBuffer + ) - val handle = new 
UcxShuffleHandle(shuffleId, driverMemory, numMaps, baseHandle) + val handle = + new UcxShuffleHandle(shuffleId, driverMemory, numMaps, baseHandle) shuffleIdToHandle.putIfAbsent(shuffleId, handle) handle } - /** - * Mapping between shuffle and metadata buffer, to deregister it when shuffle not needed. - */ + /** Mapping between shuffle and metadata buffer, to deregister it when shuffle not needed. + */ protected val shuffleIdToMetadataBuffer: mutable.Map[ShuffleId, UcpMemory] = new ConcurrentHashMap[ShuffleId, UcpMemory]().asScala - /** - * Atomically starts UcxNode singleton - one for all shuffle threads. - */ + /** Atomically starts UcxNode singleton - one for all shuffle threads. + */ def startUcxNodeIfMissing(): Unit = if (ucxNode == null) { synchronized { if (ucxNode == null) { @@ -74,13 +77,14 @@ abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) e override def unregisterShuffle(shuffleId: Int): Boolean = { shuffleIdToMetadataBuffer.remove(shuffleId).foreach(_.deregister()) - shuffleBlockResolver.asInstanceOf[CommonUcxShuffleBlockResolver].removeShuffle(shuffleId) + shuffleBlockResolver + .asInstanceOf[CommonUcxShuffleBlockResolver] + .removeShuffle(shuffleId) super.unregisterShuffle(shuffleId) } - /** - * Called on both driver and executors to finally cleanup resources. - */ + /** Called on both driver and executors to finally cleanup resources. + */ override def stop(): Unit = synchronized { logInfo("Stopping shuffle manager") shuffleIdToHandle.keys.foreach(unregisterShuffle) @@ -94,11 +98,12 @@ abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) e } -/** - * Spark shuffle handles extensions, broadcasted by TCP to executors. - * Added metadataBufferOnDriver field, that contains address and rkey of driver metadata buffer. - */ -class UcxShuffleHandle[K, V, C](override val shuffleId: Int, - val metadataBufferOnDriver: UcxRemoteMemory, - val numMaps: Int, - val baseHandle: BaseShuffleHandle[K,V,C]) extends ShuffleHandle(shuffleId) +/** Spark shuffle handles extensions, broadcasted by TCP to executors. + * Added metadataBufferOnDriver field, that contains address and rkey of driver metadata buffer. + */ +class UcxShuffleHandle[K, V, C]( + override val shuffleId: Int, + val metadataBufferOnDriver: UcxRemoteMemory, + val numMaps: Int, + val baseHandle: BaseShuffleHandle[K, V, C] +) extends ShuffleHandle(shuffleId) diff --git a/src/main/scala/org/apache/spark/shuffle/UcxShuffleConf.scala b/src/main/scala/org/apache/spark/shuffle/UcxShuffleConf.scala index 2fc6c6c2..8b7bcb9e 100755 --- a/src/main/scala/org/apache/spark/shuffle/UcxShuffleConf.scala +++ b/src/main/scala/org/apache/spark/shuffle/UcxShuffleConf.scala @@ -1,7 +1,7 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ package org.apache.spark.shuffle import scala.collection.JavaConverters._ @@ -11,33 +11,34 @@ import org.apache.spark.internal.config.{ConfigBuilder, ConfigEntry} import org.apache.spark.network.util.ByteUnit import org.apache.spark.util.Utils -/** - * Plugin configuration properties. - */ +/** Plugin configuration properties. 
+ */ class UcxShuffleConf(conf: SparkConf) extends SparkConf { private def getUcxConf(name: String) = s"spark.shuffle.ucx.$name" lazy val getNumProcesses: Int = getInt("spark.executor.instances", 1) - lazy val coresPerProcess: Int = getInt("spark.executor.cores", - Runtime.getRuntime.availableProcessors()) + lazy val coresPerProcess: Int = + getInt("spark.executor.cores", Runtime.getRuntime.availableProcessors()) - lazy val driverHost: String = conf.get(getUcxConf("driver.host"), - conf.get("spark.driver.host", "0.0.0.0")) + lazy val driverHost: String = conf.get( + getUcxConf("driver.host"), + conf.get("spark.driver.host", "0.0.0.0") + ) lazy val driverPort: Int = conf.getInt(getUcxConf("driver.port"), 55443) // Metadata lazy val RKEY_SIZE: ConfigEntry[Long] = - ConfigBuilder(getUcxConf("rkeySize")) - .doc("Maximum size of rKeyBuffer") - .bytesConf(ByteUnit.BYTE) - .createWithDefault(150) + ConfigBuilder(getUcxConf("rkeySize")) + .doc("Maximum size of rKeyBuffer") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(150) // For metadata we publish index file + data file rkeys - lazy val metadataBlockSize: Long = 2 * conf.getSizeAsBytes(RKEY_SIZE.key, - RKEY_SIZE.defaultValueString) + lazy val metadataBlockSize: Long = + 2 * conf.getSizeAsBytes(RKEY_SIZE.key, RKEY_SIZE.defaultValueString) private lazy val METADATA_RPC_BUFFER_SIZE = ConfigBuilder(getUcxConf("rpc.metadata.bufferSize")) @@ -45,46 +46,71 @@ class UcxShuffleConf(conf: SparkConf) extends SparkConf { .bytesConf(ByteUnit.BYTE) .createWithDefault(4096) - lazy val metadataRPCBufferSize: Int = conf.getSizeAsBytes(METADATA_RPC_BUFFER_SIZE.key, - METADATA_RPC_BUFFER_SIZE.defaultValueString).toInt + lazy val metadataRPCBufferSize: Int = conf + .getSizeAsBytes( + METADATA_RPC_BUFFER_SIZE.key, + METADATA_RPC_BUFFER_SIZE.defaultValueString + ) + .toInt // Memory Pool private lazy val PREALLOCATE_BUFFERS = - ConfigBuilder(getUcxConf("memory.preAllocateBuffers")) - .doc("Comma separated list of buffer size : buffer count pairs to preallocate in memory pool. E.g. 4k:1000,16k:500") - .stringConf.createWithDefault("") - - lazy val preallocateBuffersMap: java.util.Map[java.lang.Integer, java.lang.Integer] = { - conf.get(PREALLOCATE_BUFFERS).split(",").withFilter(s => !s.isEmpty) - .map(entry => entry.split(":") match { - case Array(bufferSize, bufferCount) => - (int2Integer(Utils.byteStringAsBytes(bufferSize.trim).toInt), - int2Integer(bufferCount.toInt)) - }).toMap.asJava + ConfigBuilder(getUcxConf("memory.preAllocateBuffers")) + .doc( + "Comma separated list of buffer size : buffer count pairs to preallocate in memory pool. E.g. 
4k:1000,16k:500" + ) + .stringConf + .createWithDefault("") + + lazy val preallocateBuffersMap + : java.util.Map[java.lang.Integer, java.lang.Integer] = { + conf + .get(PREALLOCATE_BUFFERS) + .split(",") + .withFilter(s => !s.isEmpty) + .map(entry => + entry.split(":") match { + case Array(bufferSize, bufferCount) => + ( + int2Integer(Utils.byteStringAsBytes(bufferSize.trim).toInt), + int2Integer(bufferCount.toInt) + ) + } + ) + .toMap + .asJava } - private lazy val MIN_BUFFER_SIZE = ConfigBuilder(getUcxConf("memory.minBufferSize")) - .doc("Minimal buffer size in memory pool.") - .bytesConf(ByteUnit.BYTE) - .createWithDefault(1024) + private lazy val MIN_BUFFER_SIZE = + ConfigBuilder(getUcxConf("memory.minBufferSize")) + .doc("Minimal buffer size in memory pool.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(1024) - lazy val minBufferSize: Long = conf.getSizeAsBytes(MIN_BUFFER_SIZE.key, - MIN_BUFFER_SIZE.defaultValueString) + lazy val minBufferSize: Long = + conf.getSizeAsBytes(MIN_BUFFER_SIZE.key, MIN_BUFFER_SIZE.defaultValueString) private lazy val MIN_REGISTRATION_SIZE = ConfigBuilder(getUcxConf("memory.minAllocationSize")) - .doc("Minimal memory registration size in memory pool.") - .bytesConf(ByteUnit.MiB) - .createWithDefault(4) - - lazy val minRegistrationSize: Int = conf.getSizeAsBytes(MIN_REGISTRATION_SIZE.key, - MIN_REGISTRATION_SIZE.defaultValueString).toInt - - private lazy val PREREGISTER_MEMORY = ConfigBuilder(getUcxConf("memory.preregister")) - .doc("Whether to do ucp mem map for allocated memory in memory pool") - .booleanConf.createWithDefault(true) - - lazy val preregisterMemory: Boolean = conf.getBoolean(PREREGISTER_MEMORY.key, PREREGISTER_MEMORY.defaultValue.get) + .doc("Minimal memory registration size in memory pool.") + .bytesConf(ByteUnit.MiB) + .createWithDefault(4) + + lazy val minRegistrationSize: Int = conf + .getSizeAsBytes( + MIN_REGISTRATION_SIZE.key, + MIN_REGISTRATION_SIZE.defaultValueString + ) + .toInt + + private lazy val PREREGISTER_MEMORY = + ConfigBuilder(getUcxConf("memory.preregister")) + .doc("Whether to do ucp mem map for allocated memory in memory pool") + .booleanConf + .createWithDefault(true) + + lazy val preregisterMemory: Boolean = + conf.getBoolean(PREREGISTER_MEMORY.key, PREREGISTER_MEMORY.defaultValue.get) lazy val useOdp: Boolean = conf.getBoolean(getUcxConf("memory.useOdp"), false) } diff --git a/src/main/scala/org/apache/spark/shuffle/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/UcxWorkerWrapper.scala index 0986a500..ab4ce95f 100755 --- a/src/main/scala/org/apache/spark/shuffle/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/UcxWorkerWrapper.scala @@ -1,7 +1,7 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. 
+ */ package org.apache.spark.shuffle import java.io.Closeable @@ -13,19 +13,28 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import org.openucx.jucx.UcxException -import org.openucx.jucx.ucp.{UcpEndpoint, UcpEndpointParams, UcpRemoteKey, UcpRequest, UcpWorker} +import org.openucx.jucx.ucp.{ + UcpEndpoint, + UcpEndpointParams, + UcpRemoteKey, + UcpRequest, + UcpWorker +} import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.shuffle.ucx.{UcxNode, UnsafeUtils} import org.apache.spark.storage.BlockManagerId import org.apache.spark.unsafe.Platform -/** - * Driver metadata buffer information that holds unpacked RkeyBuffer for this WorkerWrapper - * and fetched buffer itself. - */ -case class DriverMetadata(address: Long, driverRkey: UcpRemoteKey, length: Int, - var data: ByteBuffer) { +/** Driver metadata buffer information that holds unpacked RkeyBuffer for this WorkerWrapper + * and fetched buffer itself. + */ +case class DriverMetadata( + address: Long, + driverRkey: UcpRemoteKey, + length: Int, + var data: ByteBuffer +) { // Driver metadata is an array of blocks: // | mapId0 | mapId1 | mapId2 | mapId3 | mapId4 | mapId5 | // Each block in driver metadata has next layout: @@ -64,64 +73,68 @@ case class DriverMetadata(address: Long, driverRkey: UcpRemoteKey, length: Int, } } -/** - * Worker per thread wrapper, that maintains connection and progress logic. - */ -class UcxWorkerWrapper(val worker: UcpWorker, val conf: UcxShuffleConf, val id: Int) - extends Closeable with Logging { +/** Worker per thread wrapper, that maintains connection and progress logic. + */ +class UcxWorkerWrapper( + val worker: UcpWorker, + val conf: UcxShuffleConf, + val id: Int +) extends Closeable + with Logging { import UcxWorkerWrapper._ - private final val driverSocketAddress = new InetSocketAddress(conf.driverHost, conf.driverPort) - private final val endpointParams = new UcpEndpointParams().setSocketAddress(driverSocketAddress) + private final val driverSocketAddress = + new InetSocketAddress(conf.driverHost, conf.driverPort) + private final val endpointParams = new UcpEndpointParams() + .setSocketAddress(driverSocketAddress) .setPeerErrorHandlingMode() val driverEndpoint: UcpEndpoint = worker.newEndpoint(endpointParams) private final val connections = mutable.Map.empty[BlockManagerId, UcpEndpoint] - private final val driverMetadata = mutable.Map.empty[ShuffleId, DriverMetadata] + private final val driverMetadata = + mutable.Map.empty[ShuffleId, DriverMetadata] override def close(): Unit = { - driverMetadata.values.foreach{ + driverMetadata.values.foreach { case DriverMetadata(address, rkey, length, data) => rkey.close() } driverMetadata.clear() driverEndpoint.close() - connections.foreach{ - case (_, endpoint) => endpoint.close() + connections.foreach { case (_, endpoint) => + endpoint.close() } connections.clear() worker.close() driverMetadataBuffer.clear() } - /** - * Blocking progress single request until it's not completed. - */ + /** Blocking progress single request until it's not completed. + */ def waitRequest(request: UcpRequest): Unit = { val startTime = System.currentTimeMillis() worker.progressRequest(request) - logDebug(s"Request completed in ${System.currentTimeMillis() - startTime} ms") + logDebug( + s"Request completed in ${System.currentTimeMillis() - startTime} ms" + ) } - /** - * Blocking progress while result queue is empty. - */ + /** Blocking progress while result queue is empty. 
+ */ def fillQueueWithBlocks(queue: LinkedBlockingQueue[_]): Unit = { while (queue.isEmpty) { progress() } } - /** - * The only place for worker progress - */ + /** The only place for worker progress + */ private def progress(): Int = { worker.progress() } - /** - * Establish connections to known instances. - */ + /** Establish connections to known instances. + */ def preconnect(): Unit = { UcxNode.getWorkerAddresses.keySet().asScala.foreach(getConnection) } @@ -136,61 +149,74 @@ class UcxWorkerWrapper(val worker: UcpWorker, val conf: UcxShuffleConf, val id: while (workerAddresses.get(blockManagerId) == null) { workerAddresses.wait(timeout) if (System.currentTimeMillis() - startTime > timeout) { - throw new UcxException(s"Didn't get worker address for $blockManagerId during $timeout") + throw new UcxException( + s"Didn't get worker address for $blockManagerId during $timeout" + ) } } } } - connections.getOrElseUpdate(blockManagerId, { - logInfo(s"Worker $id connecting to $blockManagerId") - val endpointParams = new UcpEndpointParams() - .setPeerErrorHandlingMode() - .setUcpAddress(workerAddresses.get(blockManagerId)) - worker.newEndpoint(endpointParams) - }) + connections.getOrElseUpdate( + blockManagerId, { + logInfo(s"Worker $id connecting to $blockManagerId") + val endpointParams = new UcpEndpointParams() + .setPeerErrorHandlingMode() + .setUcpAddress(workerAddresses.get(blockManagerId)) + worker.newEndpoint(endpointParams) + } + ) } - /** - * Unpacks driver metadata RkeyBuffer for this worker. - * Needed to perform PUT operation to publish map output info. - */ + /** Unpacks driver metadata RkeyBuffer for this worker. + * Needed to perform PUT operation to publish map output info. + */ def getDriverMetadata(shuffleId: ShuffleId): DriverMetadata = { - driverMetadata.getOrElseUpdate(shuffleId, { - val ucxShuffleHandle = SparkEnv.get.shuffleManager.asInstanceOf[CommonUcxShuffleManager] - .shuffleIdToHandle(shuffleId) - val (address, length, rkey): (Long, Int, ByteBuffer) = (ucxShuffleHandle.metadataBufferOnDriver.getAddress, - ucxShuffleHandle.numMaps * conf.metadataBlockSize.toInt, - ucxShuffleHandle.metadataBufferOnDriver.getRkeyBuffer) - - rkey.clear() - val unpackedRkey = driverEndpoint.unpackRemoteKey(rkey) - DriverMetadata(address, unpackedRkey, length, null) - }) + driverMetadata.getOrElseUpdate( + shuffleId, { + val ucxShuffleHandle = SparkEnv.get.shuffleManager + .asInstanceOf[CommonUcxShuffleManager] + .shuffleIdToHandle(shuffleId) + val (address, length, rkey): (Long, Int, ByteBuffer) = ( + ucxShuffleHandle.metadataBufferOnDriver.getAddress, + ucxShuffleHandle.numMaps * conf.metadataBlockSize.toInt, + ucxShuffleHandle.metadataBufferOnDriver.getRkeyBuffer + ) + + rkey.clear() + val unpackedRkey = driverEndpoint.unpackRemoteKey(rkey) + DriverMetadata(address, unpackedRkey, length, null) + } + ) } - /** - * Fetches using ucp_get metadata buffer from driver, with all needed information - * for offset and data addresses and keys. - */ + /** Fetches using ucp_get metadata buffer from driver, with all needed information + * for offset and data addresses and keys. 
+ */ def fetchDriverMetadataBuffer(shuffleId: ShuffleId): DriverMetadata = { - val handle = SparkEnv.get.shuffleManager.asInstanceOf[CommonUcxShuffleManager] + val handle = SparkEnv.get.shuffleManager + .asInstanceOf[CommonUcxShuffleManager] .shuffleIdToHandle(shuffleId) val metadata = getDriverMetadata(handle.shuffleId) - UcxWorkerWrapper.driverMetadataBuffer.computeIfAbsent(shuffleId, + UcxWorkerWrapper.driverMetadataBuffer.computeIfAbsent( + shuffleId, (t: ShuffleId) => { val buffer = Platform.allocateDirectBuffer(metadata.length) val request = driverEndpoint.getNonBlocking( - metadata.address, metadata.driverRkey, buffer, null) + metadata.address, + metadata.driverRkey, + buffer, + null + ) waitRequest(request) buffer } ) if (metadata.data == null) { - metadata.data = UcxWorkerWrapper.driverMetadataBuffer.get(shuffleId) + metadata.data = UcxWorkerWrapper.driverMetadataBuffer.get(shuffleId) } metadata } @@ -203,5 +229,9 @@ object UcxWorkerWrapper { val driverMetadataBuffer = new ConcurrentHashMap[ShuffleId, ByteBuffer]() val metadataBlockSize: MapId = - SparkEnv.get.shuffleManager.asInstanceOf[CommonUcxShuffleManager].ucxShuffleConf.metadataBlockSize.toInt + SparkEnv.get.shuffleManager + .asInstanceOf[CommonUcxShuffleManager] + .ucxShuffleConf + .metadataBlockSize + .toInt } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleBlockResolver.scala index 90084f39..d28db889 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleBlockResolver.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleBlockResolver.scala @@ -1,33 +1,44 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ package org.apache.spark.shuffle.compat.spark_2_1 import java.io.{File, RandomAccessFile} import org.apache.spark.SparkEnv -import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager, IndexShuffleBlockResolver} +import org.apache.spark.shuffle.{ + CommonUcxShuffleBlockResolver, + CommonUcxShuffleManager, + IndexShuffleBlockResolver +} import org.apache.spark.storage.ShuffleIndexBlockId -/** - * Mapper entry point for UcxShuffle plugin. Performs memory registration - * of data and index files and publish addresses to driver metadata buffer. - */ +/** Mapper entry point for UcxShuffle plugin. Performs memory registration + * of data and index files and publish addresses to driver metadata buffer. + */ class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) - extends CommonUcxShuffleBlockResolver(ucxShuffleManager) { + extends CommonUcxShuffleBlockResolver(ucxShuffleManager) { private def getIndexFile(shuffleId: Int, mapId: Int): File = { - SparkEnv.get.blockManager - .diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)) + SparkEnv.get.blockManager.diskBlockManager.getFile( + ShuffleIndexBlockId( + shuffleId, + mapId, + IndexShuffleBlockResolver.NOOP_REDUCE_ID + ) + ) } - /** - * Mapper commit protocol extension. Register index and data files and publish all needed - * metadata to driver. - */ - override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Int, - lengths: Array[Long], dataTmp: File): Unit = { + /** Mapper commit protocol extension. 
Register index and data files and publish all needed + * metadata to driver. + */ + override def writeIndexFileAndCommit( + shuffleId: ShuffleId, + mapId: Int, + lengths: Array[Long], + dataTmp: File + ): Unit = { super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) val dataFile = getDataFile(shuffleId, mapId) val dataBackFile = new RandomAccessFile(dataFile, "rw") @@ -39,6 +50,13 @@ class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) val indexFile = getIndexFile(shuffleId, mapId) val indexBackFile = new RandomAccessFile(indexFile, "rw") - writeIndexFileAndCommitCommon(shuffleId, mapId, lengths, dataTmp, indexBackFile, dataBackFile) + writeIndexFileAndCommitCommon( + shuffleId, + mapId, + lengths, + dataTmp, + indexBackFile, + dataBackFile + ) } } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleManager.scala index 7087e491..b209f763 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleManager.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleManager.scala @@ -1,53 +1,78 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ package org.apache.spark.shuffle -import org.apache.spark.shuffle.compat.spark_2_1.{UcxShuffleBlockResolver, UcxShuffleReader} +import org.apache.spark.shuffle.compat.spark_2_1.{ + UcxShuffleBlockResolver, + UcxShuffleReader +} import org.apache.spark.util.ShutdownHookManager import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext} -/** - * Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin - * and injects needed logic in override methods. - */ -class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) { +/** Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin + * and injects needed logic in override methods. + */ +class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) + extends CommonUcxShuffleManager(conf, isDriver) { ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop) - /** - * Register a shuffle with the manager and obtain a handle for it to pass to tasks. - * Called on driver and guaranteed by spark that shuffle on executor will start after it. - */ - override def registerShuffle[K, V, C](shuffleId: ShuffleId, - numMaps: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + /** Register a shuffle with the manager and obtain a handle for it to pass to tasks. + * Called on driver and guaranteed by spark that shuffle on executor will start after it. + */ + override def registerShuffle[K, V, C]( + shuffleId: ShuffleId, + numMaps: Int, + dependency: ShuffleDependency[K, V, C] + ): ShuffleHandle = { assume(isDriver) - val baseHandle = super.registerShuffle(shuffleId, numMaps, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]] + val baseHandle = super + .registerShuffle(shuffleId, numMaps, dependency) + .asInstanceOf[BaseShuffleHandle[K, V, C]] registerShuffleCommon(baseHandle, shuffleId, numMaps) } - /** - * Mapper callback on executor. Just start UcxNode and use Spark mapper logic. 
- */ - override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, - context: TaskContext): ShuffleWriter[K, V] = { + /** Mapper callback on executor. Just start UcxNode and use Spark mapper logic. + */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext + ): ShuffleWriter[K, V] = { startUcxNodeIfMissing() - shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,V,_]]) - super.getWriter(handle.asInstanceOf[UcxShuffleHandle[K,V,_]].baseHandle, mapId, context) + shuffleIdToHandle.putIfAbsent( + handle.shuffleId, + handle.asInstanceOf[UcxShuffleHandle[K, V, _]] + ) + super.getWriter( + handle.asInstanceOf[UcxShuffleHandle[K, V, _]].baseHandle, + mapId, + context + ) } - override val shuffleBlockResolver: UcxShuffleBlockResolver = new UcxShuffleBlockResolver(this) + override val shuffleBlockResolver: UcxShuffleBlockResolver = + new UcxShuffleBlockResolver(this) - /** - * Reducer callback on executor. - */ - override def getReader[K, C](handle: ShuffleHandle, startPartition: Int, - endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { + /** Reducer callback on executor. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext + ): ShuffleReader[K, C] = { startUcxNodeIfMissing() - shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,_,C]]) - new UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startPartition, - endPartition, context) + shuffleIdToHandle.putIfAbsent( + handle.shuffleId, + handle.asInstanceOf[UcxShuffleHandle[K, _, C]] + ) + new UcxShuffleReader( + handle.asInstanceOf[UcxShuffleHandle[K, _, C]], + startPartition, + endPartition, + context + ) } } - diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleReader.scala index 1c7e1511..69b52886 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleReader.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleReader.scala @@ -1,7 +1,7 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ package org.apache.spark.shuffle.compat.spark_2_1 import java.io.InputStream @@ -10,130 +10,164 @@ import java.util.concurrent.LinkedBlockingQueue import org.apache.spark.internal.{Logging, config} import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.ucx.reducer.compat.spark_2_1.UcxShuffleClient -import org.apache.spark.shuffle.{ShuffleReader, UcxShuffleHandle, UcxShuffleManager} -import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.shuffle.{ + ShuffleReader, + UcxShuffleHandle, + UcxShuffleManager +} +import org.apache.spark.storage.{ + BlockId, + BlockManager, + ShuffleBlockFetcherIterator +} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark.{ + InterruptibleIterator, + MapOutputTracker, + SparkEnv, + TaskContext +} -/** - * Extension of Spark's shuffe reader with a logic of injection UcxShuffleClient, - * and lazy progress only when result queue is empty. 
- */ -class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C], - startPartition: Int, - endPartition: Int, - context: TaskContext, - serializerManager: SerializerManager = SparkEnv.get.serializerManager, - blockManager: BlockManager = SparkEnv.get.blockManager, - mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) - extends ShuffleReader[K, C] with Logging { +/** Extension of Spark's shuffle reader that injects UcxShuffleClient, + * with lazy progress only when the result queue is empty. + */ +class UcxShuffleReader[K, C]( + handle: UcxShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + context: TaskContext, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker +) extends ShuffleReader[K, C] + with Logging { - private val dep = handle.baseHandle.dependency + private val dep = handle.baseHandle.dependency - /** Read the combined key-values for this reduce task */ - override def read(): Iterator[Product2[K, C]] = { - val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() - val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] - .ucxNode.getThreadLocalWorker - val shuffleClient = new UcxShuffleClient(shuffleMetrics, workerWrapper) - val wrappedStreams = new ShuffleBlockFetcherIterator( - context, - shuffleClient, - blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, - startPartition, endPartition), - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, - SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)) + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() + val workerWrapper = SparkEnv.get.shuffleManager + .asInstanceOf[UcxShuffleManager] + .ucxNode + .getThreadLocalWorker + val shuffleClient = new UcxShuffleClient(shuffleMetrics, workerWrapper) + val wrappedStreams = new ShuffleBlockFetcherIterator( + context, + shuffleClient, + blockManager, + mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, + startPartition, + endPartition + ), + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf + .getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, + SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue) + ) - // Ucx shuffle logic - // Java reflection to get access to private results queue - val queueField = wrappedStreams.getClass.getDeclaredField( - "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results") - queueField.setAccessible(true) - val resultQueue = queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]] + // Ucx shuffle logic + // Java reflection to get access to private results queue + val queueField = wrappedStreams.getClass.getDeclaredField( + "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results" + ) + queueField.setAccessible(true) + val resultQueue = + queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]] - // Do progress if queue is empty before calling next on ShuffleIterator - val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { - override def next(): (BlockId, InputStream) = { - val startTime = 
System.currentTimeMillis() - workerWrapper.fillQueueWithBlocks(resultQueue) - shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) - wrappedStreams.next() - } + // Do progress if queue is empty before calling next on ShuffleIterator + val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { + override def next(): (BlockId, InputStream) = { + val startTime = System.currentTimeMillis() + workerWrapper.fillQueueWithBlocks(resultQueue) + shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) + wrappedStreams.next() + } - override def hasNext: Boolean = { - val result = wrappedStreams.hasNext - if (!result) { - shuffleClient.close() - } - result + override def hasNext: Boolean = { + val result = wrappedStreams.hasNext + if (!result) { + shuffleClient.close() } + result } - // End of ucx shuffle logic + } + // End of ucx shuffle logic - val serializerInstance = dep.serializer.newInstance() - val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => - // Note: the asKeyValueIterator below wraps a key/value iterator inside of a - // NextIterator. The NextIterator makes sure that close() is called on the - // underlying InputStream when all records have been read. - serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator - } + val serializerInstance = dep.serializer.newInstance() + val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } - // Update the context task metrics for each record read. - val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() - val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( - recordIter.map { record => - readMetrics.incRecordsRead(1) - record - }, - context.taskMetrics().mergeShuffleReadMetrics()) + // Update the context task metrics for each record read. 
+ val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map { record => + readMetrics.incRecordsRead(1) + record + }, + context.taskMetrics().mergeShuffleReadMetrics() + ) - // An interruptible iterator must be used here in order to support task cancellation - val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = + new InterruptibleIterator[(Any, Any)](context, metricIter) - val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + val aggregatedIter: Iterator[Product2[K, C]] = + if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { // We are reading values that are already combined - val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] - dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + val combinedKeyValuesIterator = + interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get + .combineCombinersByKey(combinedKeyValuesIterator, context) } else { // We don't know the value type, but also don't care -- the dependency *should* // have made sure its compatible w/ this aggregator, which will convert the value // type to the combined type C - val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + val keyValuesIterator = + interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } - // Sort the output if there is a sort ordering defined. - val resultIter = dep.keyOrdering match { - case Some(keyOrd: Ordering[K]) => - // Create an ExternalSorter to sort the data. - val sorter = - new ExternalSorter[K, C, C](context, - ordering = Some(keyOrd), serializer = dep.serializer) - sorter.insertAll(aggregatedIter) - context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) - context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) - // Use completion callback to stop sorter if task was finished/cancelled. - CompletionIterator[Product2[K, C], - Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) - case None => - aggregatedIter - } + // Sort the output if there is a sort ordering defined. + val resultIter = dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. + val sorter = + new ExternalSorter[K, C, C]( + context, + ordering = Some(keyOrd), + serializer = dep.serializer + ) + sorter.insertAll(aggregatedIter) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]]( + sorter.iterator, + sorter.stop() + ) + case None => + aggregatedIter + } - resultIter match { - case _: InterruptibleIterator[Product2[K, C]] => resultIter - case _ => - // Use another interruptible iterator here to support task cancellation as aggregator - // or(and) sorter may have consumed previous interruptible iterator. 
- new InterruptibleIterator[Product2[K, C]](context, resultIter) - } + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) } + } } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleBlockResolver.scala index a47b0af6..6311b5e0 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleBlockResolver.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleBlockResolver.scala @@ -1,33 +1,44 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ package org.apache.spark.shuffle.compat.spark_2_4 import java.io.{File, RandomAccessFile} import org.apache.spark.SparkEnv -import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager, IndexShuffleBlockResolver} +import org.apache.spark.shuffle.{ + CommonUcxShuffleBlockResolver, + CommonUcxShuffleManager, + IndexShuffleBlockResolver +} import org.apache.spark.storage.ShuffleIndexBlockId -/** - * Mapper entry point for UcxShuffle plugin. Performs memory registration - * of data and index files and publish addresses to driver metadata buffer. - */ +/** Mapper entry point for UcxShuffle plugin. Performs memory registration + * of data and index files and publish addresses to driver metadata buffer. + */ class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) - extends CommonUcxShuffleBlockResolver(ucxShuffleManager) { + extends CommonUcxShuffleBlockResolver(ucxShuffleManager) { private def getIndexFile(shuffleId: Int, mapId: Int): File = { - SparkEnv.get.blockManager - .diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)) + SparkEnv.get.blockManager.diskBlockManager.getFile( + ShuffleIndexBlockId( + shuffleId, + mapId, + IndexShuffleBlockResolver.NOOP_REDUCE_ID + ) + ) } - /** - * Mapper commit protocol extension. Register index and data files and publish all needed - * metadata to driver. - */ - override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Int, - lengths: Array[Long], dataTmp: File): Unit = { + /** Mapper commit protocol extension. Register index and data files and publish all needed + * metadata to driver. 
+ */ + override def writeIndexFileAndCommit( + shuffleId: ShuffleId, + mapId: Int, + lengths: Array[Long], + dataTmp: File + ): Unit = { super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) val dataFile = getDataFile(shuffleId, mapId) val dataBackFile = new RandomAccessFile(dataFile, "rw") @@ -39,6 +50,13 @@ class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) val indexFile = getIndexFile(shuffleId, mapId) val indexBackFile = new RandomAccessFile(indexFile, "rw") - writeIndexFileAndCommitCommon(shuffleId, mapId, lengths, dataTmp, indexBackFile, dataBackFile) + writeIndexFileAndCommitCommon( + shuffleId, + mapId, + lengths, + dataTmp, + indexBackFile, + dataBackFile + ) } } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleManager.scala index df72287d..841cc141 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleManager.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleManager.scala @@ -1,53 +1,78 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ package org.apache.spark.shuffle -import org.apache.spark.shuffle.compat.spark_2_4.{UcxShuffleBlockResolver, UcxShuffleReader} +import org.apache.spark.shuffle.compat.spark_2_4.{ + UcxShuffleBlockResolver, + UcxShuffleReader +} import org.apache.spark.util.ShutdownHookManager import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext} -/** - * Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin - * and injects needed logic in override methods. - */ -class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) { +/** Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin + * and injects needed logic in override methods. + */ +class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) + extends CommonUcxShuffleManager(conf, isDriver) { ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop) - /** - * Register a shuffle with the manager and obtain a handle for it to pass to tasks. - * Called on driver and guaranteed by spark that shuffle on executor will start after it. - */ - override def registerShuffle[K, V, C](shuffleId: ShuffleId, - numMaps: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + /** Register a shuffle with the manager and obtain a handle for it to pass to tasks. + * Called on driver and guaranteed by spark that shuffle on executor will start after it. + */ + override def registerShuffle[K, V, C]( + shuffleId: ShuffleId, + numMaps: Int, + dependency: ShuffleDependency[K, V, C] + ): ShuffleHandle = { assume(isDriver) - val baseHandle = super.registerShuffle(shuffleId, numMaps, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]] + val baseHandle = super + .registerShuffle(shuffleId, numMaps, dependency) + .asInstanceOf[BaseShuffleHandle[K, V, C]] registerShuffleCommon(baseHandle, shuffleId, numMaps) } - /** - * Mapper callback on executor. Just start UcxNode and use Spark mapper logic. - */ - override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, - context: TaskContext): ShuffleWriter[K, V] = { + /** Mapper callback on executor. Just start UcxNode and use Spark mapper logic. 
+ */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext + ): ShuffleWriter[K, V] = { startUcxNodeIfMissing() - shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,V,_]]) - super.getWriter(handle.asInstanceOf[UcxShuffleHandle[K,V,_]].baseHandle, mapId, context) + shuffleIdToHandle.putIfAbsent( + handle.shuffleId, + handle.asInstanceOf[UcxShuffleHandle[K, V, _]] + ) + super.getWriter( + handle.asInstanceOf[UcxShuffleHandle[K, V, _]].baseHandle, + mapId, + context + ) } - override val shuffleBlockResolver: UcxShuffleBlockResolver = new UcxShuffleBlockResolver(this) + override val shuffleBlockResolver: UcxShuffleBlockResolver = + new UcxShuffleBlockResolver(this) - /** - * Reducer callback on executor. - */ - override def getReader[K, C](handle: ShuffleHandle, startPartition: Int, - endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { + /** Reducer callback on executor. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext + ): ShuffleReader[K, C] = { startUcxNodeIfMissing() - shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,_,C]]) - new UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startPartition, - endPartition, context) + shuffleIdToHandle.putIfAbsent( + handle.shuffleId, + handle.asInstanceOf[UcxShuffleHandle[K, _, C]] + ) + new UcxShuffleReader( + handle.asInstanceOf[UcxShuffleHandle[K, _, C]], + startPartition, + endPartition, + context + ) } } - diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala index a1ad2294..474c4988 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala @@ -1,146 +1,180 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ package org.apache.spark.shuffle.compat.spark_2_4 import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue -import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark.{ + InterruptibleIterator, + MapOutputTracker, + SparkEnv, + TaskContext +} import org.apache.spark.internal.{Logging, config} import org.apache.spark.serializer.SerializerManager -import org.apache.spark.shuffle.{ShuffleReader, UcxShuffleHandle, UcxShuffleManager} +import org.apache.spark.shuffle.{ + ShuffleReader, + UcxShuffleHandle, + UcxShuffleManager +} import org.apache.spark.shuffle.ucx.reducer.compat.spark_2_4.UcxShuffleClient -import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.storage.{ + BlockId, + BlockManager, + ShuffleBlockFetcherIterator +} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -/** - * Extension of Spark's shuffe reader with a logic of injection UcxShuffleClient, - * and lazy progress only when result queue is empty. 
- */ -class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C], - startPartition: Int, - endPartition: Int, - context: TaskContext, - serializerManager: SerializerManager = SparkEnv.get.serializerManager, - blockManager: BlockManager = SparkEnv.get.blockManager, - mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) - extends ShuffleReader[K, C] with Logging { +/** Extension of Spark's shuffle reader that injects UcxShuffleClient, + * with lazy progress only when the result queue is empty. + */ +class UcxShuffleReader[K, C]( + handle: UcxShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + context: TaskContext, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker +) extends ShuffleReader[K, C] + with Logging { - private val dep = handle.baseHandle.dependency + private val dep = handle.baseHandle.dependency - /** Read the combined key-values for this reduce task */ - override def read(): Iterator[Product2[K, C]] = { - val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() - val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] - .ucxNode.getThreadLocalWorker - val shuffleClient = new UcxShuffleClient(shuffleMetrics, workerWrapper) - val wrappedStreams = new ShuffleBlockFetcherIterator( - context, - shuffleClient, - blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, - startPartition, endPartition), - serializerManager.wrapStream, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, - SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), - SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() + val workerWrapper = SparkEnv.get.shuffleManager + .asInstanceOf[UcxShuffleManager] + .ucxNode + .getThreadLocalWorker + val shuffleClient = new UcxShuffleClient(shuffleMetrics, workerWrapper) + val wrappedStreams = new ShuffleBlockFetcherIterator( + context, + shuffleClient, + blockManager, + mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, + startPartition, + endPartition + ), + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf + .getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, + SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true) + ) - // Ucx shuffle logic - // Java reflection to get access to private 
results queue + val queueField = wrappedStreams.getClass.getDeclaredField( + "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results" + ) + queueField.setAccessible(true) + val resultQueue = + queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]] - // Do progress if queue is empty before calling next on ShuffleIterator - val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { - override def next(): (BlockId, InputStream) = { - val startTime = System.currentTimeMillis() - workerWrapper.fillQueueWithBlocks(resultQueue) - shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) - wrappedStreams.next() - } + // Do progress if queue is empty before calling next on ShuffleIterator + val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { + override def next(): (BlockId, InputStream) = { + val startTime = System.currentTimeMillis() + workerWrapper.fillQueueWithBlocks(resultQueue) + shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) + wrappedStreams.next() + } - override def hasNext: Boolean = { - val result = wrappedStreams.hasNext - if (!result) { - shuffleClient.close() - } - result + override def hasNext: Boolean = { + val result = wrappedStreams.hasNext + if (!result) { + shuffleClient.close() } + result } - // End of ucx shuffle logic + } + // End of ucx shuffle logic - val serializerInstance = dep.serializer.newInstance() - val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => - // Note: the asKeyValueIterator below wraps a key/value iterator inside of a - // NextIterator. The NextIterator makes sure that close() is called on the - // underlying InputStream when all records have been read. - serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator - } + val serializerInstance = dep.serializer.newInstance() + val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } - // Update the context task metrics for each record read. - val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() - val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( - recordIter.map { record => - readMetrics.incRecordsRead(1) - record - }, - context.taskMetrics().mergeShuffleReadMetrics()) + // Update the context task metrics for each record read. 
+ val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map { record => + readMetrics.incRecordsRead(1) + record + }, + context.taskMetrics().mergeShuffleReadMetrics() + ) - // An interruptible iterator must be used here in order to support task cancellation - val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = + new InterruptibleIterator[(Any, Any)](context, metricIter) - val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + val aggregatedIter: Iterator[Product2[K, C]] = + if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { // We are reading values that are already combined - val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] - dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + val combinedKeyValuesIterator = + interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get + .combineCombinersByKey(combinedKeyValuesIterator, context) } else { // We don't know the value type, but also don't care -- the dependency *should* // have made sure its compatible w/ this aggregator, which will convert the value // type to the combined type C - val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + val keyValuesIterator = + interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } - // Sort the output if there is a sort ordering defined. - val resultIter = dep.keyOrdering match { - case Some(keyOrd: Ordering[K]) => - // Create an ExternalSorter to sort the data. - val sorter = - new ExternalSorter[K, C, C](context, - ordering = Some(keyOrd), serializer = dep.serializer) - sorter.insertAll(aggregatedIter) - context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) - context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) - // Use completion callback to stop sorter if task was finished/cancelled. - context.addTaskCompletionListener[Unit](_ => { - sorter.stop() - }) - CompletionIterator[Product2[K, C], - Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) - case None => - aggregatedIter - } + // Sort the output if there is a sort ordering defined. + val resultIter = dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. + val sorter = + new ExternalSorter[K, C, C]( + context, + ordering = Some(keyOrd), + serializer = dep.serializer + ) + sorter.insertAll(aggregatedIter) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. 
+ context.addTaskCompletionListener[Unit](_ => { + sorter.stop() + }) + CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]]( + sorter.iterator, + sorter.stop() + ) + case None => + aggregatedIter + } - resultIter match { - case _: InterruptibleIterator[Product2[K, C]] => resultIter - case _ => - // Use another interruptible iterator here to support task cancellation as aggregator - // or(and) sorter may have consumed previous interruptible iterator. - new InterruptibleIterator[Product2[K, C]](context, resultIter) - } + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) } + } } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleDataIO.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleDataIO.scala index 6ecb74cd..2d961046 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleDataIO.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleDataIO.scala @@ -1,7 +1,7 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ package org.apache.spark.shuffle.compat.spark_3_0 import org.apache.spark.SparkConf @@ -9,10 +9,11 @@ import org.apache.spark.internal.Logging import org.apache.spark.shuffle.api.ShuffleExecutorComponents import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO -/** - * Ucx local disk IO plugin to handle logic of writing to local disk and shuffle memory registration. - */ -case class UcxLocalDiskShuffleDataIO(sparkConf: SparkConf) extends LocalDiskShuffleDataIO(sparkConf) with Logging { +/** Ucx local disk IO plugin to handle logic of writing to local disk and shuffle memory registration. + */ +case class UcxLocalDiskShuffleDataIO(sparkConf: SparkConf) + extends LocalDiskShuffleDataIO(sparkConf) + with Logging { override def executor(): ShuffleExecutorComponents = { new UcxLocalDiskShuffleExecutorComponents(sparkConf) diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala index b8e5e72e..d6091e43 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala @@ -1,7 +1,7 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. + * See file LICENSE for terms. 
+ */ package org.apache.spark.shuffle.compat.spark_3_0 import java.util @@ -9,39 +9,67 @@ import java.util.Optional import org.apache.spark.internal.Logging import org.apache.spark.{SparkConf, SparkEnv} -import org.apache.spark.shuffle.sort.io.{LocalDiskShuffleExecutorComponents, LocalDiskShuffleMapOutputWriter, LocalDiskSingleSpillMapOutputWriter} +import org.apache.spark.shuffle.sort.io.{ + LocalDiskShuffleExecutorComponents, + LocalDiskShuffleMapOutputWriter, + LocalDiskSingleSpillMapOutputWriter +} import org.apache.spark.shuffle.UcxShuffleManager -import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter} +import org.apache.spark.shuffle.api.{ + ShuffleMapOutputWriter, + SingleSpillShuffleMapOutputWriter +} -/** - * Entry point to UCX executor. - */ +/** Entry point to UCX executor. + */ class UcxLocalDiskShuffleExecutorComponents(sparkConf: SparkConf) - extends LocalDiskShuffleExecutorComponents(sparkConf) with Logging{ + extends LocalDiskShuffleExecutorComponents(sparkConf) + with Logging { private var blockResolver: UcxShuffleBlockResolver = _ - override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = { - val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] + override def initializeExecutor( + appId: String, + execId: String, + extraConfigs: util.Map[String, String] + ): Unit = { + val ucxShuffleManager = + SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] ucxShuffleManager.startUcxNodeIfMissing() blockResolver = ucxShuffleManager.shuffleBlockResolver } - override def createMapOutputWriter(shuffleId: Int, mapTaskId: Long, numPartitions: Int): ShuffleMapOutputWriter = { + override def createMapOutputWriter( + shuffleId: Int, + mapTaskId: Long, + numPartitions: Int + ): ShuffleMapOutputWriter = { if (blockResolver == null) { throw new IllegalStateException( - "Executor components must be initialized before getting writers.") + "Executor components must be initialized before getting writers." + ) } new LocalDiskShuffleMapOutputWriter( - shuffleId, mapTaskId, numPartitions, blockResolver, sparkConf) + shuffleId, + mapTaskId, + numPartitions, + blockResolver, + sparkConf + ) } - override def createSingleFileMapOutputWriter(shuffleId: Int, mapId: Long): Optional[SingleSpillShuffleMapOutputWriter] = { + override def createSingleFileMapOutputWriter( + shuffleId: Int, + mapId: Long + ): Optional[SingleSpillShuffleMapOutputWriter] = { if (blockResolver == null) { throw new IllegalStateException( - "Executor components must be initialized before getting writers.") + "Executor components must be initialized before getting writers." + ) } - Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver)) + Optional.of( + new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver) + ) } } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleBlockResolver.scala index 8e172f96..d2a762da 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleBlockResolver.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleBlockResolver.scala @@ -1,7 +1,7 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. 
+ */ package org.apache.spark.shuffle.compat.spark_3_0 import java.io.{File, RandomAccessFile} @@ -9,29 +9,39 @@ import java.io.{File, RandomAccessFile} import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.network.shuffle.ExecutorDiskUtils import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID -import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager} +import org.apache.spark.shuffle.{ + CommonUcxShuffleBlockResolver, + CommonUcxShuffleManager +} import org.apache.spark.storage.ShuffleIndexBlockId -/** - * Mapper entry point for UcxShuffle plugin. Performs memory registration - * of data and index files and publish addresses to driver metadata buffer. - */ +/** Mapper entry point for UcxShuffle plugin. Performs memory registration + * of data and index files and publish addresses to driver metadata buffer. + */ class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) - extends CommonUcxShuffleBlockResolver(ucxShuffleManager) { + extends CommonUcxShuffleBlockResolver(ucxShuffleManager) { private def getIndexFile( - shuffleId: Int, - mapId: Long, - dirs: Option[Array[String]] = None): File = { + shuffleId: Int, + mapId: Long, + dirs: Option[Array[String]] = None + ): File = { val blockId = ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID) val blockManager = SparkEnv.get.blockManager dirs - .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name)) + .map( + ExecutorDiskUtils + .getFile(_, blockManager.subDirsPerLocalDir, blockId.name) + ) .getOrElse(blockManager.diskBlockManager.getFile(blockId)) } - override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long, - lengths: Array[Long], dataTmp: File): Unit = { + override def writeIndexFileAndCommit( + shuffleId: ShuffleId, + mapId: Long, + lengths: Array[Long], + dataTmp: File + ): Unit = { super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) // In Spark-3.0 MapId is long and unique among all jobs in spark. We need to use partitionId as offset // in metadata buffer @@ -47,6 +57,13 @@ class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) val indexFile = getIndexFile(shuffleId, mapId) val indexBackFile = new RandomAccessFile(indexFile, "rw") - writeIndexFileAndCommitCommon(shuffleId, partitionId, lengths, dataTmp, indexBackFile, dataBackFile) + writeIndexFileAndCommitCommon( + shuffleId, + partitionId, + lengths, + dataTmp, + indexBackFile, + dataBackFile + ) } } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleManager.scala index 4d4f8c09..2e504c7a 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleManager.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleManager.scala @@ -1,40 +1,64 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. 
+ */ package org.apache.spark.shuffle import scala.collection.JavaConverters._ import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.shuffle.compat.spark_3_0.{UcxShuffleBlockResolver, UcxShuffleReader} -import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, SortShuffleWriter, UnsafeShuffleWriter} +import org.apache.spark.shuffle.compat.spark_3_0.{ + UcxShuffleBlockResolver, + UcxShuffleReader +} +import org.apache.spark.shuffle.sort.{ + SerializedShuffleHandle, + SortShuffleWriter, + UnsafeShuffleWriter +} import org.apache.spark.util.ShutdownHookManager import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} -/** - * Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin - * and injects needed logic in override methods. - */ -class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) { +/** Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin + * and injects needed logic in override methods. + */ +class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) + extends CommonUcxShuffleManager(conf, isDriver) { ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop) - private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) + private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents( + conf + ) override val shuffleBlockResolver = new UcxShuffleBlockResolver(this) - override def registerShuffle[K, V, C](shuffleId: ShuffleId, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + override def registerShuffle[K, V, C]( + shuffleId: ShuffleId, + dependency: ShuffleDependency[K, V, C] + ): ShuffleHandle = { assume(isDriver) - val numMaps = dependency.partitioner.numPartitions - val baseHandle = super.registerShuffle(shuffleId, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]] + val numMaps = dependency.rdd.getNumPartitions + val baseHandle = super + .registerShuffle(shuffleId, dependency) + .asInstanceOf[BaseShuffleHandle[K, V, C]] registerShuffleCommon(baseHandle, shuffleId, numMaps) } - override def getWriter[K, V](handle: ShuffleHandle, mapId: Long, context: TaskContext, - metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { - shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K, V, _]]) + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter + ): ShuffleWriter[K, V] = { + shuffleIdToHandle.putIfAbsent( + handle.shuffleId, + handle.asInstanceOf[UcxShuffleHandle[K, V, _]] + ) val env = SparkEnv.get handle.asInstanceOf[UcxShuffleHandle[K, V, _]].baseHandle match { - case unsafeShuffleHandle: SerializedShuffleHandle[K@unchecked, V@unchecked] => + case unsafeShuffleHandle: SerializedShuffleHandle[ + K @unchecked, + V @unchecked + ] => new UnsafeShuffleWriter( env.blockManager, context.taskMemoryManager(), @@ -43,31 +67,54 @@ class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends context, env.conf, metrics, - shuffleExecutorComponents) - case other: BaseShuffleHandle[K@unchecked, V@unchecked, _] => + shuffleExecutorComponents + ) + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter( - shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents) + shuffleBlockResolver, + other, + mapId, + context, + shuffleExecutorComponents + ) } } - override def getReader[K, 
C](handle: ShuffleHandle, startPartition: MapId, endPartition: MapId, - context: TaskContext, metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: MapId, + endPartition: MapId, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): ShuffleReader[K, C] = { startUcxNodeIfMissing() - shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K, _, C]]) - new UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startPartition, endPartition, - context, readMetrics = metrics, shouldBatchFetch = true) + shuffleIdToHandle.putIfAbsent( + handle.shuffleId, + handle.asInstanceOf[UcxShuffleHandle[K, _, C]] + ) + new UcxShuffleReader( + handle.asInstanceOf[UcxShuffleHandle[K, _, C]], + startPartition, + endPartition, + context, + readMetrics = metrics, + shouldBatchFetch = true + ) } - - private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { - val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() - val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX) - .toMap + private def loadShuffleExecutorComponents( + conf: SparkConf + ): ShuffleExecutorComponents = { + val executorComponents = + ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() + val extraConfigs = + conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap executorComponents.initializeExecutor( conf.getAppId, SparkEnv.get.executorId, - extraConfigs.asJava) + extraConfigs.asJava + ) executorComponents } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala index 7dc34a7b..5eed9d37 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala @@ -1,7 +1,7 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ package org.apache.spark.shuffle.compat.spark_3_0 import java.io.InputStream @@ -13,161 +13,215 @@ import org.apache.spark.internal.{Logging, config} import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.ucx.reducer.compat.spark_3_0.UcxShuffleClient -import org.apache.spark.shuffle.{ShuffleReadMetricsReporter, ShuffleReader, UcxShuffleHandle, UcxShuffleManager} -import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockBatchId, ShuffleBlockFetcherIterator, ShuffleBlockId} +import org.apache.spark.shuffle.{ + ShuffleReadMetricsReporter, + ShuffleReader, + UcxShuffleHandle, + UcxShuffleManager +} +import org.apache.spark.storage.{ + BlockId, + BlockManager, + ShuffleBlockBatchId, + ShuffleBlockFetcherIterator, + ShuffleBlockId +} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.{InterruptibleIterator, SparkEnv, SparkException, TaskContext} - +import org.apache.spark.{ + InterruptibleIterator, + SparkEnv, + SparkException, + TaskContext +} -/** - * Extension of Spark's shuffe reader with a logic of injection UcxShuffleClient, - * and lazy progress only when result queue is empty. 
- */ -class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C], - startPartition: Int, - endPartition: Int, - context: TaskContext, - serializerManager: SerializerManager = SparkEnv.get.serializerManager, - blockManager: BlockManager = SparkEnv.get.blockManager, - readMetrics: ShuffleReadMetricsReporter, - shouldBatchFetch: Boolean = false) extends ShuffleReader[K, C] with Logging { - - private val dep = handle.baseHandle.dependency - - /** Read the combined key-values for this reduce task */ - override def read(): Iterator[Product2[K, C]] = { - val (blocksByAddressIterator1, blocksByAddressIterator2) = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( - handle.shuffleId, startPartition, endPartition).duplicate - val mapIdToBlockIndex = blocksByAddressIterator2.flatMap{ - case (_, blocks) => blocks.map { - case (blockId, _, mapIdx) => blockId match { - case x: ShuffleBlockId => (x.mapId.asInstanceOf[java.lang.Long], mapIdx.asInstanceOf[java.lang.Integer]) - case x: ShuffleBlockBatchId => (x.mapId.asInstanceOf[java.lang.Long], mapIdx.asInstanceOf[java.lang.Integer]) +/** Extension of Spark's shuffle reader that injects UcxShuffleClient, + * with lazy progress only when the result queue is empty. + */ +class UcxShuffleReader[K, C]( + handle: UcxShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + context: TaskContext, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + blockManager: BlockManager = SparkEnv.get.blockManager, + readMetrics: ShuffleReadMetricsReporter, + shouldBatchFetch: Boolean = false +) extends ShuffleReader[K, C] + with Logging { + + private val dep = handle.baseHandle.dependency + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + val (blocksByAddressIterator1, blocksByAddressIterator2) = + SparkEnv.get.mapOutputTracker + .getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) + .duplicate + val mapIdToBlockIndex = blocksByAddressIterator2.flatMap { + case (_, blocks) => + blocks.map { case (blockId, _, mapIdx) => + blockId match { + case x: ShuffleBlockId => + ( + x.mapId.asInstanceOf[java.lang.Long], + mapIdx.asInstanceOf[java.lang.Integer] + ) + case x: ShuffleBlockBatchId => + ( + x.mapId.asInstanceOf[java.lang.Long], + mapIdx.asInstanceOf[java.lang.Integer] + ) case _ => throw new SparkException("Unknown block") } } - }.toMap - - val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] - .ucxNode.getThreadLocalWorker - val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() - val shuffleClient = new UcxShuffleClient(handle.shuffleId, workerWrapper, mapIdToBlockIndex.asJava, shuffleMetrics) - val shuffleIterator = new ShuffleBlockFetcherIterator( - context, - shuffleClient, - blockManager, - blocksByAddressIterator1, - serializerManager.wrapStream, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, - SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), - SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), - SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), - readMetrics, - fetchContinuousBlocksInBatch) - - val wrappedStreams = shuffleIterator.toCompletionIterator - - // Ucx shuffle logic - // Java reflection to get access to 
private results queue - val queueField = shuffleIterator.getClass.getDeclaredField( - "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results") - queueField.setAccessible(true) - val resultQueue = queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]] - - // Do progress if queue is empty before calling next on ShuffleIterator - val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { - override def next(): (BlockId, InputStream) = { - val startTime = System.currentTimeMillis() - workerWrapper.fillQueueWithBlocks(resultQueue) - readMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) - wrappedStreams.next() - } + }.toMap + + val workerWrapper = SparkEnv.get.shuffleManager + .asInstanceOf[UcxShuffleManager] + .ucxNode + .getThreadLocalWorker + val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() + val shuffleClient = new UcxShuffleClient( + handle.shuffleId, + workerWrapper, + mapIdToBlockIndex.asJava, + shuffleMetrics + ) + val shuffleIterator = new ShuffleBlockFetcherIterator( + context, + shuffleClient, + blockManager, + blocksByAddressIterator1, + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, + SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + readMetrics, + fetchContinuousBlocksInBatch + ) + + val wrappedStreams = shuffleIterator.toCompletionIterator + + // Ucx shuffle logic + // Java reflection to get access to private results queue + val queueField = shuffleIterator.getClass.getDeclaredField( + "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results" + ) + queueField.setAccessible(true) + val resultQueue = + queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]] + + // Do progress if queue is empty before calling next on ShuffleIterator + val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { + override def next(): (BlockId, InputStream) = { + val startTime = System.currentTimeMillis() + workerWrapper.fillQueueWithBlocks(resultQueue) + readMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) + wrappedStreams.next() + } - override def hasNext: Boolean = { - val result = wrappedStreams.hasNext - if (!result) { - shuffleClient.close() - } - result + override def hasNext: Boolean = { + val result = wrappedStreams.hasNext + if (!result) { + shuffleClient.close() } + result } - // End of ucx shuffle logic - - val serializerInstance = dep.serializer.newInstance() - - // Create a key/value iterator for each stream - val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => - // Note: the asKeyValueIterator below wraps a key/value iterator inside of a - // NextIterator. The NextIterator makes sure that close() is called on the - // underlying InputStream when all records have been read. - serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator - } + } + // End of ucx shuffle logic - // Update the context task metrics for each record read. 
- val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( - recordIter.map { record => - readMetrics.incRecordsRead(1) - record - }, - context.taskMetrics().mergeShuffleReadMetrics()) + val serializerInstance = dep.serializer.newInstance() - // An interruptible iterator must be used here in order to support task cancellation - val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + // Create a key/value iterator for each stream + val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } - val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + // Update the context task metrics for each record read. + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map { record => + readMetrics.incRecordsRead(1) + record + }, + context.taskMetrics().mergeShuffleReadMetrics() + ) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = + new InterruptibleIterator[(Any, Any)](context, metricIter) + + val aggregatedIter: Iterator[Product2[K, C]] = + if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { // We are reading values that are already combined - val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] - dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + val combinedKeyValuesIterator = + interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get + .combineCombinersByKey(combinedKeyValuesIterator, context) } else { // We don't know the value type, but also don't care -- the dependency *should* // have made sure its compatible w/ this aggregator, which will convert the value // type to the combined type C - val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + val keyValuesIterator = + interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } - // Sort the output if there is a sort ordering defined. - val resultIter = dep.keyOrdering match { - case Some(keyOrd: Ordering[K]) => - // Create an ExternalSorter to sort the data. - val sorter = - new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) - sorter.insertAll(aggregatedIter) - context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) - context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) - // Use completion callback to stop sorter if task was finished/cancelled. - context.addTaskCompletionListener[Unit](_ => { - sorter.stop() - }) - CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) - case None => - aggregatedIter - } + // Sort the output if there is a sort ordering defined. + val resultIter = dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. 
+ val sorter = + new ExternalSorter[K, C, C]( + context, + ordering = Some(keyOrd), + serializer = dep.serializer + ) + sorter.insertAll(aggregatedIter) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + context.addTaskCompletionListener[Unit](_ => { + sorter.stop() + }) + CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]]( + sorter.iterator, + sorter.stop() + ) + case None => + aggregatedIter + } - resultIter match { - case _: InterruptibleIterator[Product2[K, C]] => resultIter - case _ => - // Use another interruptible iterator here to support task cancellation as aggregator - // or(and) sorter may have consumed previous interruptible iterator. - new InterruptibleIterator[Product2[K, C]](context, resultIter) - } + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) } + } private def fetchContinuousBlocksInBatch: Boolean = { val conf = SparkEnv.get.conf - val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects + val serializerRelocatable = + dep.serializer.supportsRelocationOfSerializedObjects val compressed = conf.get(config.SHUFFLE_COMPRESS) val codecConcatenation = if (compressed) { - CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf)) + CompressionCodec.supportsConcatenationOfSerializedStreams( + CompressionCodec.createCodec(conf) + ) } else { true } @@ -176,12 +230,14 @@ class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C], val doBatchFetch = shouldBatchFetch && serializerRelocatable && (!compressed || codecConcatenation) && !useOldFetchProtocol if (shouldBatchFetch && !doBatchFetch) { - logWarning("The feature tag of continuous shuffle block fetching is set to true, but " + - "we can not enable the feature because other conditions are not satisfied. " + - s"Shuffle compress: $compressed, serializer ${dep.serializer.getClass.getName} " + - s"relocatable: $serializerRelocatable, " + - s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " + - s"$useOldFetchProtocol.") + logWarning( + "The feature tag of continuous shuffle block fetching is set to true, but " + + "we can not enable the feature because other conditions are not satisfied. " + + s"Shuffle compress: $compressed, serializer ${dep.serializer.getClass.getName} " + + s"relocatable: $serializerRelocatable, " + + s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " + + s"$useOldFetchProtocol." + ) } doBatchFetch }
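
All three UcxShuffleReader variants in this patch share one idiom: reflectively grab ShuffleBlockFetcherIterator's private results queue, then drive UCX worker progress each time the task asks for the next block, so fetch-wait time is spent progressing the transport rather than sleeping on the queue. A minimal, self-contained sketch of that idiom follows; `UcxWorker` is a hypothetical stand-in for the project's UcxWorkerWrapper, and the object is placed under org.apache.spark only because the fetcher class is private[spark]:

```scala
package org.apache.spark.shuffle.ucx.sketch

import java.io.InputStream
import java.util.concurrent.LinkedBlockingQueue

import org.apache.spark.storage.{BlockId, ShuffleBlockFetcherIterator}

// Hypothetical stand-in for UcxWorkerWrapper: progresses the UCX worker
// until the fetcher's result queue has something to hand out.
trait UcxWorker {
  def fillQueueWithBlocks(queue: LinkedBlockingQueue[_]): Unit
}

object UcxProgressIterator {
  /** Wraps the fetcher so each next() first lets UCX make progress. */
  def apply(fetcher: ShuffleBlockFetcherIterator,
            worker: UcxWorker): Iterator[(BlockId, InputStream)] = {
    // Scala name-mangled private field, the same name used in the readers above.
    val queueField = fetcher.getClass.getDeclaredField(
      "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
    queueField.setAccessible(true)
    val results =
      queueField.get(fetcher).asInstanceOf[LinkedBlockingQueue[_]]

    new Iterator[(BlockId, InputStream)] {
      override def hasNext: Boolean = fetcher.hasNext
      override def next(): (BlockId, InputStream) = {
        worker.fillQueueWithBlocks(results) // progress until a result lands
        fetcher.next()
      }
    }
  }
}
```

The readers additionally time the fillQueueWithBlocks call and add the delta via incFetchWaitTime, which keeps Spark's fetch-wait metric meaningful when the blocking happens inside UCX progress instead of the queue.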