Mixture of Experts

  1. Vanilla implementation

  2. max_model_len: derived from the model's max_position_embeddings when not set explicitly
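     A quick way to check where that default comes from (a sketch, assuming the Mixtral-8x22B checkpoint
     and that this vLLM build falls back to the HF config when --max-model-len is unset):

     ```python
     from transformers import AutoConfig

     # When --max-model-len is not passed, vLLM derives the context limit from the
     # model config; for most decoder-only models this is max_position_embeddings.
     cfg = AutoConfig.from_pretrained("mistralai/Mixtral-8x22B-v0.1")
     print(cfg.max_position_embeddings)  # 65536 for Mixtral-8x22B
     ```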

  3. Graph Capture Size

    1. max_seq_len_to_capture: naming note (this parameter was previously called max_context_len_to_capture)
    2. _verify_cuda_graph clamps it to the model length:
      self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_model_len)
    3. _get_graph_batch_size: returns the padded batch size for a given actual batch size; the captured sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT, ... (see the sketch below)
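     A minimal sketch of that padding rule, assuming _BATCH_SIZE_ALIGNMENT = 8 as in recent vLLM
     releases (the exact source may differ by version):

     ```python
     _BATCH_SIZE_ALIGNMENT = 8  # assumed value; check the vLLM version in use

     def _get_graph_batch_size(batch_size: int) -> int:
         """Pad an actual batch size up to the next captured size: 1, 2, 4, 8, 16, 24, ..."""
         if batch_size <= 2:
             return batch_size
         if batch_size <= 4:
             return 4
         # Round up to the next multiple of the alignment.
         return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1)
                 // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

     assert _get_graph_batch_size(3) == 4
     assert _get_graph_batch_size(9) == 16
     assert _get_graph_batch_size(32) == 32
     ```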
  4. Graph usage condition

    1. _use_captured_graph : link
    2. Print log for Mixtral-8x22B with --input-len 8192 --output-len 3 --batch-size 32
       (max_seq_len_to_capture is set to 8192 + 256 = 8448; a sketch of the combined condition follows the log):
               decode_only: False && not enforce_eager: False
               ,  batch_size: 65536 <= _BATCH_SIZES_TO_CAPTURE: 8192
               ,  max_decode_seq_len: 0, max_encoder_seq_len: 0 <=  max_seq_len_to_capture: 8448
                  batch_size: 65536 <= max_batchsize_to_capture: 256
                  --> result (_use_captured_graph) = False
             (the same prefill-phase check is logged three more times with identical values)
               decode_only: True && not enforce_eager: False
               ,  batch_size: 32 <= _BATCH_SIZES_TO_CAPTURE: 8192
               ,  max_decode_seq_len: **8193**, max_encoder_seq_len: 0 <=  max_seq_len_to_capture: 8448
                  batch_size: 32 <= max_batchsize_to_capture: 256
                  --> result (_use_captured_graph) = True
               decode_only: True && not enforce_eager: False
               ,  batch_size: 32 <= _BATCH_SIZES_TO_CAPTURE: 8192
               ,  max_decode_seq_len: **8194**, max_encoder_seq_len: 0 <=  max_seq_len_to_capture: 8448
                  batch_size: 32 <= max_batchsize_to_capture: 256
                  --> result (_use_captured_graph) = True
               decode_only: True && not enforce_eager: False
               ,  batch_size: 32 <= _BATCH_SIZES_TO_CAPTURE: 8192
               ,  max_decode_seq_len: **8195**, max_encoder_seq_len: 0 <=  max_seq_len_to_capture: 8448
                  batch_size: 32 <= max_batchsize_to_capture: 256
                  --> result (_use_captured_graph) = True
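       Putting the logged conditions together, the decision is roughly the conjunction below
       (a sketch reconstructed from the log, not the exact vLLM source; attribute names are assumptions):

       ```python
       def _use_captured_graph(self, batch_size: int, max_decode_seq_len: int,
                               max_encoder_seq_len: int = 0) -> bool:
           # Every condition printed in the log above must hold for the captured CUDA graph to run.
           return (self.decode_only                               # prefill batches never use the graph
                   and not self.enforce_eager                     # --enforce-eager disables capture
                   and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]  # largest captured batch size (8192 above)
                   and batch_size <= self.max_batchsize_to_capture
                   and max_decode_seq_len <= self.max_seq_len_to_capture
                   and max_encoder_seq_len <= self.max_seq_len_to_capture)
       ```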
           
  5. Chunked prefill is enabled automatically for max_model_len > 32k: link (see the sketch below)
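     It can also be toggled explicitly; a sketch, assuming the Mixtral checkpoint name and that the
     installed vLLM version exposes these engine arguments:

     ```python
     from vllm import LLM

     # enable_chunked_prefill can be forced on; when left unset, recent vLLM versions
     # enable it automatically for long-context models (max_model_len > 32k).
     llm = LLM(model="mistralai/Mixtral-8x22B-v0.1",
               enable_chunked_prefill=True,
               max_num_batched_tokens=2048)  # prefill chunk budget per scheduler step
     ```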

  6. Decode latency IS affected by prefill length.

    • BS=240, output=200

      | Input-len | Decode latency |
      |-----------|----------------|
      | 512       | 63 ms          |
      | 1024      | 66 ms          |
      | 2048      | 69 ms          |

      Reason: the KV cache is larger for a longer context length, so the paged_attn kernel takes more time
      (see the back-of-the-envelope sketch below).
                 BS=240 | _paged_attn kernel time: In512: 41 us  --vs--  In2048: 121 us
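      A back-of-the-envelope check of the reason above (assuming Mixtral-8x22B attention shapes:
      56 layers, 8 KV heads, head_dim 128, fp16 KV cache; the numbers are illustrative only):

      ```python
      # Per-token KV cache = 2 (K and V) * num_kv_heads * head_dim * bytes_per_element, per layer.
      layers, kv_heads, head_dim, dtype_bytes = 56, 8, 128, 2
      per_token = 2 * kv_heads * head_dim * dtype_bytes * layers   # bytes of KV per token

      for ctx in (512, 1024, 2048):
          gib = 240 * ctx * per_token / 2**30                      # BS=240 sequences at this context length
          print(f"context {ctx:4d}: ~{gib:5.1f} GiB of KV cache read per decode step")
      ```

      Going from 512 to 2048 tokens of context quadruples the KV cache the paged_attn kernel has to read
      each step, which lines up with the 41 us vs 121 us kernel times above.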
             
  7. Config file names to avoid confusion (see the device-name sketch below):

        MI300 file name:
             AMD_Instinct_MI300X.json
        MI308 file names:
             AMD_Instinct_MI300X_OAM.json
             AMD_Instinct_MI308X_OAM.json
             AMD_Radeon_Graphics.json
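     These names come straight from the reported GPU device string; a quick way to see which file a given
     machine will look for (a sketch; the full tuned-config file name pattern may differ across vLLM versions):

     ```python
     import torch

     # The fused-MoE kernels pick their tuning JSON from the device name with spaces
     # replaced by underscores, e.g. "AMD Instinct MI300X" -> "AMD_Instinct_MI300X".
     device_name = torch.cuda.get_device_name(0).replace(" ", "_")
     print(device_name)  # e.g. AMD_Instinct_MI300X, AMD_Instinct_MI308X_OAM, or AMD_Radeon_Graphics
     ```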