
[Bug] large concurrency service broken #3005

Open
fan-niu opened this issue Oct 31, 2024 · 4 comments
Labels
bug Confirmed bugs

Comments


fan-niu commented Oct 31, 2024

🐛 Bug

A service started from Meta-Llama-3.1-70B-Instruct (fp8) crashes when handling high concurrency.

To Reproduce

Convert model

Refer to issue #2982.

Start service

mlc_llm serve Meta-Llama-3.1-70B-Instruct-fp8-e4m3_e4m3_f16 \
    --model-lib Meta-Llama-3.1-70B-Instruct-fp8-e4m3_e4m3_f16/libs/h100-e4m3_e4m3_f16-cuda.so \
    --mode server \
    --host 127.0.0.1 \
    --port 8081 \
    --device cuda \
    --prefix-cache-mode disable \
    --enable-debug \
    --overrides "tensor_parallel_shards=2"
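
Once the server is up, a quick smoke test confirms the endpoint responds before load testing. This is a minimal sketch; it assumes the standard OpenAI-compatible /v1/models route (mlc_llm serve exposes an OpenAI-compatible REST API) on the host and port above.

# smoke_test.py -- sanity check before load testing (sketch).
# Assumption: the OpenAI-compatible /v1/models route is served at this URL.
import requests

resp = requests.get("http://127.0.0.1:8081/v1/models", timeout=10)
resp.raise_for_status()
print(resp.json())  # lists the served model if the server is healthy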

Concurrency test

The service breaks under 70 concurrent requests, with an average input length of 2310 tokens and an average output length of 50 tokens.
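
For reference, a load generator along these lines reproduces the pattern. This is a sketch, not the exact client used for the test: it assumes the OpenAI-compatible /v1/chat/completions route, and PROMPT is a placeholder standing in for the ~2310-token inputs.

# load_test.py -- sketch of a 70-way concurrent client (not the original tool).
# Assumptions: OpenAI-compatible /v1/chat/completions on 127.0.0.1:8081;
# PROMPT is a placeholder for a ~2310-token input.
import asyncio
import aiohttp

URL = "http://127.0.0.1:8081/v1/chat/completions"
CONCURRENCY = 70   # concurrency level that triggers the crash
NUM_ROUNDS = 10    # rounds of simultaneous requests
PROMPT = "placeholder -- replace with a ~2310-token prompt"

async def one_request(session: aiohttp.ClientSession) -> None:
    payload = {
        "model": "Meta-Llama-3.1-70B-Instruct-fp8-e4m3_e4m3_f16",
        "messages": [{"role": "user", "content": PROMPT}],
        "max_tokens": 50,  # matches the ~50 average output tokens
    }
    async with session.post(URL, json=payload) as resp:
        await resp.json()

async def main() -> None:
    timeout = aiohttp.ClientTimeout(total=600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        for _ in range(NUM_ROUNDS):
            # fire 70 requests at once and wait for the batch to finish
            await asyncio.gather(*(one_request(session) for _ in range(CONCURRENCY)))

if __name__ == "__main__":
    asyncio.run(main())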

Error messages

==== backtrace (tid:   1147) ====
 0 0x0000000000042520 __sigaction()  ???:0
 1 0x0000000003200e60 tvm::runtime::relax_vm::PagedAttentionKVCacheObj::GetTotalSequenceLength()  ???:0
 2 0x00000000031e5274 tvm::runtime::TypedPackedFunc<int (tvm::runtime::relax_vm::AttentionKVCache)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::runtime::relax_vm::AttentionKVCache, tvm::runtime::relax_vm::AttentionKVCacheObj, int, , void>(int (tvm::runtime::relax_vm::AttentionKVCacheObj::*)() const)::{lambda(tvm::runtime::relax_vm::AttentionKVCache)#1}>(tvm::runtime::Registry::set_body_method<tvm::runtime::relax_vm::AttentionKVCache, tvm::runtime::relax_vm::AttentionKVCacheObj, int, , void>(int (tvm::runtime::relax_vm::AttentionKVCacheObj::*)() const)::{lambda(tvm::runtime::relax_vm::AttentionKVCache)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()()  ???:0
 3 0x00000000031e5325 tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<int (tvm::runtime::relax_vm::AttentionKVCache)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::runtime::relax_vm::AttentionKVCache, tvm::runtime::relax_vm::AttentionKVCacheObj, int, , void>(int (tvm::runtime::relax_vm::AttentionKVCacheObj::*)() const)::{lambda(tvm::runtime::relax_vm::AttentionKVCache)#1}>(tvm::runtime::Registry::set_body_method<tvm::runtime::relax_vm::AttentionKVCache, tvm::runtime::relax_vm::AttentionKVCacheObj, int, , void>(int (tvm::runtime::relax_vm::AttentionKVCacheObj::*)() const)::{lambda(tvm::runtime::relax_vm::AttentionKVCache)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call()  ???:0
 4 0x00000000002d597d tvm::runtime::PackedFuncObj::CallPacked()  /workspace/mlc-llm/3rdparty/tvm/include/tvm/runtime/packed_func.h:1398
 5 0x00000000002d597d tvm::runtime::PackedFunc::operator()<tvm::runtime::ObjectRef const&>()  /workspace/mlc-llm/3rdparty/tvm/include/tvm/runtime/packed_func.h:1932
 6 0x00000000002d597d mlc::llm::serve::ModelImpl::GetCurrentTotalSequenceLength()  /workspace/mlc-llm/cpp/serve/model.cc:775
 7 0x000000000027bdec mlc::llm::serve::BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill()  /workspace/mlc-llm/cpp/serve/engine_actions/batch_prefill_base.cc:91
 8 0x000000000027bdec tvm::runtime::ObjectPtr<tvm::runtime::Object>::~ObjectPtr()  /workspace/mlc-llm/3rdparty/tvm/include/tvm/runtime/object.h:404
 9 0x000000000027bdec tvm::runtime::ObjectRef::~ObjectRef()  /workspace/mlc-llm/3rdparty/tvm/include/tvm/runtime/object.h:519
10 0x000000000027bdec mlc::llm::serve::Model::~Model()  /workspace/mlc-llm/cpp/serve/engine_actions/../model.h:366
11 0x000000000027bdec mlc::llm::serve::BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill()  /workspace/mlc-llm/cpp/serve/engine_actions/batch_prefill_base.cc:91
12 0x000000000029d4ac mlc::llm::serve::NewRequestPrefillActionObj::Step()  /workspace/mlc-llm/cpp/serve/engine_actions/new_request_prefill.cc:35
13 0x000000000029d4ac std::_Vector_base<mlc::llm::serve::BatchPrefillBaseActionObj::PrefillInput, std::allocator<mlc::llm::serve::BatchPrefillBaseActionObj::PrefillInput> >::_Vector_impl_data::_M_swap_data()  /opt/rh/gcc-toolset-11/root/usr/include/c++/11/bits/stl_vector.h:123
14 0x000000000029d4ac std::vector<mlc::llm::serve::BatchPrefillBaseActionObj::PrefillInput, std::allocator<mlc::llm::serve::BatchPrefillBaseActionObj::PrefillInput> >::_M_move_assign()  /opt/rh/gcc-toolset-11/root/usr/include/c++/11/bits/stl_vector.h:1818
15 0x000000000029d4ac std::vector<mlc::llm::serve::BatchPrefillBaseActionObj::PrefillInput, std::allocator<mlc::llm::serve::BatchPrefillBaseActionObj::PrefillInput> >::operator=()  /opt/rh/gcc-toolset-11/root/usr/include/c++/11/bits/stl_vector.h:714
16 0x000000000029d4ac mlc::llm::serve::NewRequestPrefillActionObj::Step()  /workspace/mlc-llm/cpp/serve/engine_actions/new_request_prefill.cc:35
17 0x000000000024f9f9 mlc::llm::serve::EngineImpl::Step()  /workspace/mlc-llm/cpp/serve/engine.cc:594
18 0x000000000024f9f9 tvm::runtime::ObjectPtr<tvm::runtime::Object>::operator=()  /workspace/mlc-llm/3rdparty/tvm/include/tvm/runtime/object.h:444
19 0x000000000024f9f9 tvm::runtime::Array<mlc::llm::serve::Request, void>::operator=()  /workspace/mlc-llm/3rdparty/tvm/include/tvm/runtime/container/array.h:362
20 0x000000000024f9f9 mlc::llm::serve::EngineImpl::Step()  /workspace/mlc-llm/cpp/serve/engine.cc:594
21 0x000000000030a093 mlc::llm::serve::ThreadedEngineImpl::RunBackgroundLoop()  /workspace/mlc-llm/cpp/serve/threaded_engine.cc:182
22 0x000000000030a093 std::atomic<bool>::load()  /opt/rh/gcc-toolset-11/root/usr/include/c++/11/atomic:112
23 0x000000000030a093 mlc::llm::serve::ThreadedEngineImpl::RunBackgroundLoop()  /workspace/mlc-llm/cpp/serve/threaded_engine.cc:138
24 0x000000000313ceba TVMFuncCall()  ???:0
25 0x0000000000022215 __pyx_f_3tvm_4_ffi_4_cy3_4core_FuncCall()  core.cpp:0
26 0x0000000000022af8 __pyx_pw_3tvm_4_ffi_4_cy3_4core_14PackedFuncBase_5__call__()  core.cpp:0
27 0x000000000016942b PyObject_Call()  ???:0
28 0x00000000001455d7 _PyEval_EvalFrameDefault()  ???:0
29 0x000000000015a9fc _PyFunction_Vectorcall()  ???:0
30 0x000000000014345c _PyEval_EvalFrameDefault()  ???:0
31 0x000000000015a9fc _PyFunction_Vectorcall()  ???:0
32 0x000000000014345c _PyEval_EvalFrameDefault()  ???:0
33 0x0000000000168a51 PyMethod_New()  ???:0
34 0x0000000000291f3a _PyDict_SetItem_KnownHash()  ???:0
35 0x0000000000286ef8 _PyObject_RealIsInstance()  ???:0
36 0x0000000000094ac3 pthread_condattr_setpshared()  ???:0
37 0x0000000000126850 __xmknodat()  ???:0
=================================
run_service_fp8_70b.sh: line 22:   889 Segmentation fault      (core dumped) mlc_llm serve $MLC_MODEL --model-lib $MLC_LIB --mode server --host $SERVER_ADDR --port $SERVER_PORT --device cuda --prefix-cache-mode disable --enable-debug --overrides "tensor_parallel_shards=2"

Expected behavior

Is there any way to keep the service from crashing other than lowering the concurrency? For example, could requests that cannot be processed immediately be evicted, or could responses simply be returned more slowly?
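
Until the root cause is fixed, one stopgap along the lines described above is client- or proxy-side admission control: cap the number of in-flight requests to the engine and let excess requests wait in a queue, so callers see higher latency instead of a crashed server. The sketch below assumes the crash only occurs above some in-flight threshold; MAX_IN_FLIGHT is a guess to be tuned empirically.

# gate.py -- sketch of admission control in front of the engine.
# Callers can still submit 70 requests, but at most MAX_IN_FLIGHT reach the
# server at once; the rest wait on the semaphore (latency instead of a crash).
import asyncio
import aiohttp

MAX_IN_FLIGHT = 32  # assumed safe bound below the crash threshold; tune this
_gate = asyncio.Semaphore(MAX_IN_FLIGHT)

async def gated_post(session: aiohttp.ClientSession, url: str, payload: dict) -> dict:
    async with _gate:  # excess callers queue here instead of hitting the server
        async with session.post(url, json=payload) as resp:
            return await resp.json()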

Environment

  • Platform (e.g. WebGPU/Vulkan/IOS/Android/CUDA): CUDA
  • Operating system (e.g. Ubuntu/Windows/MacOS/...): Ubuntu
  • Device (e.g. iPhone 12 Pro, PC+RTX 3090, ...): H100
  • How you installed MLC-LLM (conda, source): conda
  • How you installed TVM-Unity (pip, source): pip
  • Python version (e.g. 3.10): Python 3.10.12
  • GPU driver version (if applicable): 535.129.03
  • CUDA/cuDNN version (if applicable): 12.3
  • TVM Unity Hash Tag (python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))", applicable if you compile models):
USE_GTEST: AUTO
SUMMARIZE: OFF
TVM_DEBUG_WITH_ABI_CHANGE: OFF
USE_IOS_RPC: OFF
USE_MSC: OFF
USE_ETHOSU: 
CUDA_VERSION: 12.3
USE_LIBBACKTRACE: AUTO
DLPACK_PATH: 3rdparty/dlpack/include
USE_TENSORRT_CODEGEN: OFF
USE_THRUST: ON
USE_TARGET_ONNX: OFF
USE_AOT_EXECUTOR: ON
BUILD_DUMMY_LIBTVM: OFF
USE_CUDNN: OFF
USE_TENSORRT_RUNTIME: OFF
USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR: OFF
USE_CCACHE: AUTO
USE_ARM_COMPUTE_LIB: OFF
USE_CPP_RTVM: 
USE_OPENCL_GTEST: /path/to/opencl/gtest
TVM_LOG_BEFORE_THROW: OFF
USE_MKL: OFF
USE_PT_TVMDSOOP: OFF
MLIR_VERSION: NOT-FOUND
USE_CLML: OFF
USE_STACKVM_RUNTIME: OFF
USE_GRAPH_EXECUTOR_CUDA_GRAPH: OFF
ROCM_PATH: /opt/rocm
USE_DNNL: OFF
USE_MSCCL: OFF
USE_NNAPI_RUNTIME: OFF
USE_VITIS_AI: OFF
USE_MLIR: OFF
USE_RCCL: OFF
USE_LLVM: llvm-config --ignore-libllvm --link-static
USE_VERILATOR: OFF
USE_TF_TVMDSOOP: OFF
USE_THREADS: ON
USE_MSVC_MT: OFF
BACKTRACE_ON_SEGFAULT: OFF
USE_GRAPH_EXECUTOR: ON
USE_NCCL: ON
USE_ROCBLAS: OFF
GIT_COMMIT_HASH: dc87019cb805d0a1f0075f6415cc979ef337ec2a
USE_VULKAN: ON
USE_RUST_EXT: OFF
USE_CUTLASS: ON
USE_CPP_RPC: OFF
USE_HEXAGON: OFF
USE_CUSTOM_LOGGING: OFF
USE_UMA: OFF
USE_FALLBACK_STL_MAP: OFF
USE_SORT: ON
USE_RTTI: ON
GIT_COMMIT_TIME: 2024-09-28 00:31:12 -0400
USE_HIPBLAS: OFF
USE_HEXAGON_SDK: /path/to/sdk
USE_BLAS: none
USE_ETHOSN: OFF
USE_LIBTORCH: OFF
USE_RANDOM: ON
USE_CUDA: ON
USE_COREML: OFF
USE_AMX: OFF
BUILD_STATIC_RUNTIME: OFF
USE_CMSISNN: OFF
USE_KHRONOS_SPIRV: OFF
USE_CLML_GRAPH_EXECUTOR: OFF
USE_TFLITE: OFF
USE_HEXAGON_GTEST: /path/to/hexagon/gtest
PICOJSON_PATH: 3rdparty/picojson
USE_OPENCL_ENABLE_HOST_PTR: OFF
INSTALL_DEV: OFF
USE_PROFILER: ON
USE_NNPACK: OFF
LLVM_VERSION: 17.0.6
USE_MRVL: OFF
USE_OPENCL: OFF
COMPILER_RT_PATH: 3rdparty/compiler-rt
USE_NNAPI_CODEGEN: OFF
RANG_PATH: 3rdparty/rang/include
USE_SPIRV_KHR_INTEGER_DOT_PRODUCT: OFF
USE_OPENMP: OFF
USE_BNNS: OFF
USE_FLASHINFER: ON
USE_CUBLAS: ON
USE_METAL: OFF
USE_MICRO_STANDALONE_RUNTIME: OFF
USE_HEXAGON_EXTERNAL_LIBS: OFF
USE_ALTERNATIVE_LINKER: AUTO
USE_BYODT_POSIT: OFF
USE_NVSHMEM: OFF
USE_HEXAGON_RPC: OFF
USE_MICRO: OFF
DMLC_PATH: 3rdparty/dmlc-core/include
INDEX_DEFAULT_I64: ON
USE_RELAY_DEBUG: OFF
USE_RPC: ON
USE_TENSORFLOW_PATH: none
TVM_CLML_VERSION: 
USE_MIOPEN: OFF
USE_ROCM: OFF
USE_PAPI: OFF
USE_CURAND: OFF
TVM_CXX_COMPILER_PATH: /opt/rh/gcc-toolset-11/root/usr/bin/c++
HIDE_PRIVATE_SYMBOLS: ON
  • Any other relevant information:

Additional context

The service runs normally at low concurrency but crashes at high concurrency. The service should stay up regardless of load.

fan-niu added the bug (Confirmed bugs) label on Oct 31, 2024
MasterJH5574 (Member) commented

Thanks for reporting. We will find time and try to reproduce it. Meanwhile, may I ask how often this segmentation fault happens? Does it happen every time you use “70 concurrency”?


fan-niu commented Nov 6, 2024

@MasterJH5574 Yes, every time 70 concurrency is used, the service crashes. Thanks for your reply; looking forward to good news.


fan-niu commented Nov 11, 2024

Hi @MasterJH5574, is there any progress on this issue? Thanks.

MasterJH5574 (Member) commented

@fan-niu Sorry, we haven't had enough bandwidth to work on this yet. We will try our best to get to it as soon as possible.
