상세 컨텐츠

본문 제목

pytorch - Linear layer의 datatype

프로그래밍/딥러닝

by 릿카。 2024. 4. 6. 10:06

본문

Linear layer를 가진 모델로 학습을 시키려다보면 간혹 ‘mat1과 mat2의 데이터타입이 달라서 어쩌구..’라는 오류가 뜬다.

보통 입력데이터와 레이어가 사용하는 데이터타입이 달라서 발생하는 일인데, Linear layer는 기본적으로 torch.float32를 사용한다.

Linear뿐만아니라 많은 Layer들이 torch.float32를 지원하니까 알아두자.

https://pytorch.org/docs/stable/generated/torch.nn.Linear.html


예컨대 입력데이터가 torch.long이면 Linear layer는 계산을 진행할 수 없고, 그래서 오류남.

해결법은 쉬운데

그냥 입력데이터를 torch.float32 타입으로 바꿔주면된다. 예컨대 a = torch.tensor(data, dtype=torch.float32)

관련글 더보기

댓글 영역