- 링크:
https://arxiv.org/pdf/2409.12917
Background
- STaR (Self-Taught Reasoner)
- 방법론: LLM이 스스로 resoning chain을 생성 정답을 낸 reasoning trace만 (룰기반채점) 모아서 SFT
- Distibution shift: base 모델의 오류는 고치는데, 새로 학습한 모델은 분포가 달라 또 못 맞춤
- Behavior collapse: 올바른 chain 데이터로만 학습하면서 모델이 점차 첫 시도에서는 맞는 답을 생성 -> 자신의 실수를 찾아 고치는 self-correction 능력을 배우지 않게 됨
- -> Distribution shift와 Behavior collapse를 개선하자
Methods

- SCoRe
- oracle feedback 없이 문제에 대한 응답을 생성하고 에러를 수정
- SFT
- offline model로 entirely self-generated data를 생성하여 SFT
- 1. prompt를 입력으로 넣어 문제에 대한 응답을 생성함
- 2. 이 original answer과 instruction으로 모델을 한 번 finetuning
- cross entropy loss와 KL-divergence loss를 합하여 total loss로 사용
- 어떻게 SFT로 mismatch를 해결해? -> 한 모델이 만든 original answer과 advanced answer를 하나의 세트로 묶어서 사용함으로써 기존의 train!=test mismatch를 해결함
- 단 distribution shift를 완전히 해결하지는 못함 -> offline data 사용이 근본적인 원인
- 실제 correction 능력 향상은 크지 않음
- reward는 answer에 대한 exact match로 룰기반으로 계산함
- progress reward
- a bonus 𝑏̂(𝒚2 ∣𝒚1 , 𝒚 ∗ ) ∶= 𝛼 ⋅ (𝑟̂(𝒚2 , 𝒚 ∗ ) − 𝑟̂(𝒚1 , 𝒚 ∗ )),
- 그냥 학습시키면 첫번째 단계에서 좋은 답변을 생성하고 두번째 단계에서는 교정 없이 그대로 답변을 사용하는 collapse가 발생할 수 있음 -> 보너스 리워드로 해결
- original answer이 advanced answer에서 정답으로 고쳐지면 더 많은 리워드를 받도록
- 다른 방법론과의 차이점
- Multi-turn
- online reinforce learning
Experiment & Analysis
- Figure 1
- MATH task에서 성능 개선을 보임

- Table 1
- Base 모델에 비해 self-correction이 크게 개선됨
- ㅅ (i->c): 틀렸던 문제를 고친 비율
- ㅅ (c->i): 맞았던 문제를 틀린 비율
- i->c가 크게 증가하고, c->i가 낮아서 이미 맞춘 문제는 오답으로 바꾸지 않고, 틀린 문제는 정답으로 잘 푼다는 것을 증명
- -> 이걸 반박하는게 S2R

Result
- 추론 task에서 성능 개선을 보임
Limitation
- 추론 task로 한정적임