개요
이 글은 스퀴즈비츠의 김태수님이 발표한 내용으로 두 논문을 정리하였다.
LLM에 토큰을 하나씩 생성할 때마다 굉장히 많은 weight를 불러와야 한다. 그래서 DRAM bandwidth가 문제가 된다. Autoregressive 방식이 GPU를 완전하 활용하지 못하는 문제가 발생한다. 이를 해결하기 위한 방법 중 하나로 Speculative Decoding이 있다. Speculative Decoding은 1개의 프롬프트를 1 배치로 처리하는 것이 아니라, 예측한 여러 토큰들을 동시에 재입력하여 병렬 처리하는 기술이다. 따라서 모델은 여러 입력 문장을 배치 단위로 처리한다.
Speculative Decoding
이 논문은 Draft, Verification을 단순하게 구현하여 최적의 토큰을 찾는다. 이때 적절한 토큰이 아니면 물러나는데 이 rejection을 잘하는 방법이 중요하다. 이 논문은 computational resource 활용성을 높이기 위해, Speculative Sampling 방법을 제안한다.
- Speculative Sampling:
- 더 작은 모델(Mq)을 사용하여 여러 후보 토큰을 생성하고, 이를 큰 모델(Mp)을 사용하여 병렬로 검증한다.
- 후보가 큰 모델에 의해 검증되면 최종 출력으로 선택되고, 그렇지 않으면 조정된 분포에서 재샘플링한다.
- Draft와 Verification 단계:
- Draft 단계: 더 작은 효율적인 모델(Mq)을 사용하여 여러 토큰을 병렬로 예측한다.
- Verification 단계: 큰 목표 모델(Mp)을 사용하여 Draft 단계에서 생성된 예측을 검증하고, 필요한 경우 수정한다.
장점:
- 추론 시간 향상
- 1 배치에서 하드웨어 유틸라지션이 더 좋음 단점:
- 드래프트 모델을 어떻게 학습하고 선택할 것인가?
- 서로 다른 모델의 얼라이먼트를 어떻게 잡아줄 것인가?
Medusa
이 논문은 Original Model에서 트랜스포머 레이어의 last hidden을 추출하여 Medusa Head를 학습하여 적합한 드래프팅 모델을 만들수 있다. 그리고 서로 다른 이종 모델을 잘 컴바인하여 호스팅할 수 있다.
- Training Strategies
- Medusa-1은 기존 모델을 고정한 채 추가적인 디코딩 헤드를 조정하고, Medusa-2는 백본 모델과 디코딩 헤드를 함께 조정한다.
- Tree Attention
- Medusa Head를 통해 생성한 토큰 간 self attention을 연산한다.
- 이를 통해 토큰의 컨텍스트를 유치한 채, 여러 후보를 처리할 수 있다.
장점:
- 추가 학습을 통해 제약 조건에 맞추어 적용이 가능하다.
- 여러 디코딩 헤드를 병렬로 예측하여 추론 속도를 향상시킬 수 있다.
단점:
- 최적의 트리 구조를 찾기 위한 오버헤드가 필요하다.
레퍼런스
[1] Fast Inference from Transformers via Speculative Decoding
[2] Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads