Linear layer를 가진 모델로 학습을 시키려다보면 간혹 ‘mat1과 mat2의 데이터타입이 달라서 어쩌구..’라는 오류가 뜬다.
보통 입력데이터와 레이어가 사용하는 데이터타입이 달라서 발생하는 일인데, Linear layer는 기본적으로 torch.float32를 사용한다.
Linear뿐만아니라 많은 Layer들이 torch.float32를 지원하니까 알아두자.
예컨대 입력데이터가 torch.long이면 Linear layer는 계산을 진행할 수 없고, 그래서 오류남.
해결법은 쉬운데
그냥 입력데이터를 torch.float32 타입으로 바꿔주면된다. 예컨대 a = torch.tensor(data, dtype=torch.float32)
CNN parameter 계산 관련 참고할만한 블로그 (1) | 2024.09.04 |
---|---|
(pytorch) pytorch-tensor와 numpy-ndarray 간에 변환하기 (0) | 2024.09.04 |
(pytorch) 텐서를 CPU와 GPU 간에 이동시키기 (0) | 2024.09.04 |
PyTorch (1) Dataset, DataLoader (0) | 2024.07.29 |
댓글 영역