자 이틀간 제 골머리를 썩히게 만들었던 axotl의 레이어 가중치 묶고 푸는것에 대한 문제 해결 방법 공유하도록 하겠습니다.

우선 어떤것이 문제냐?


deepspeed zero3 이놈이 문제였다.


문제인 이유: 우선 이놈으로 가지고 오게 되면, 우선 명시적인 텐서의 size가 0이 되어버리는 이상한 문제가 있다.


물론, 그대로 full fine-tuning이나, pre-train 하게 되면, 상관 없지만, 특히 내가 원하는 레이어만 얼리고 싶은데, 이놈을 기존의 방식대로 풀어주려면, 텐서사이즈가 0으로 torch에서 찍히니, 문제가 발생


나는 8번 레이어를 묶어주고 싶었는데, 8번 레이어중 하나의 shape을 찍어보면, 이렇게 보인다. 



model.layers.8.self_attn.q_proj.weight

torch.Size([0])



for param in model.parameters():

    param.requires_grad = False


for name, param in model.named_parameters():

    if any(pattern.match(name) for pattern in compiled_patterns):

        if is_main_process():

               LOG.debug(f"unfreezing {name}")

        param.requires_grad = True


기존 axotl의 freeze.py 의 코드이다. 

여기서 발생하는 문제가 무엇이냐면,  param.requires_grad = False는 문제가 없지만, size가 0인 텐서에 param.requires_grad = True를 해주면서 에러가 발생한다.


아마 지금 올라와 있는 solar 모델도 비슷한 이유로 lm_head와 embede_token 레이어만 풀어주는 것 같다.

그러면 lm_head와 embede_token은 문제가 없는가?


=>  이건 신기하게, deepspeed 영향을 안받는것 같다.

아래의 레이어를 찍어보면 이렇게 찍힌다.


lm_head.weight

torch.Size([32000, 4096]) 이렇게 말이다.


그래서 이것은 기존의 axotl방식으로 해도 문제가 밣생하지 않는다.


그러면 나는 내가 원하는 레이어들만 묶고 풀어주려면 어떻게 해야하나요?


방법은 간단하다.


이걸 기존의 방식으로 해결하려고 거의 25시간동안 디버깅을 했지만, 


풀수 있는 로직은 간단하였다.


반대로 풀어주면되는것.


즉 이미 모델을 불러오는 순간 require_grad=True인 상태이다.


그러니, 내가 원하는 레이어를 제외하면 require_grad=False로 만들어주면 아주 간단하게 해결이 된다...



즉 기존의 코드가 아닌

for param in model.parameters():

    param.requires_grad = False


for name, param in model.named_parameters():

    if any(pattern.match(name) for pattern in compiled_patterns):

        if is_main_process():

               LOG.debug(f"unfreezing {name}")

        param.requires_grad = True

        pass

    else:

        param.requires_grad = False


이렇게 반대로 해결해주면 문제는 간단하게 해결된다....


그나저나 진짜 마이크로 소프트 이시키들은 윈도우 만들듯이 라이브러리 만들어놔서 진짜 사람 잠도 못자게 하고 빡치게 만드네.....



해결방안 찾고자 라이브러리 직접 수정하고 별짓하다가 포기하고 그냥 1분만에 해결함....


이런 이슈가지고 어려워 하는 사람들 있던데 그냥 삽질말고 가장 간단한 방법으로 해결하시길...



* 추가

deepspeed로 훈련할때 bfloat16은 쓰면 안되는것 같습니다.

우선 loss가 explode합니다. fp32로 해야합니다.

weighted decay할때, 로스가 정확하게 계산 안된다는 이슈가 있네요...