aiai:: onnx 파일 확장자 설명

 

ONNX

ONNX(Open Neural Network Exchange)는 기계 학습 모델을 위한 개방형 파일 형식으로, 다양한 프레임워크 간에 모델을 쉽게 이동할 수 있게 해주는 중간 표현(IR) 형식입니다.

주요 특징:

  • Microsoft와 Facebook(현 Meta)가 2017년에 공동 개발했습니다
  • PyTorch, TensorFlow, scikit-learn 등 다양한 프레임워크 간의 모델 호환성 제공
  • 다양한 하드웨어와 플랫폼에서 모델을 실행할 수 있게 해줌
  • 모델 구조와 매개변수를 모두 저장

ONNX 파일이 유용한 이유:

  • 프레임워크 간 이동성 - 한 프레임워크에서 모델을 훈련하고 다른 프레임워크에서 추론 가능
  • 최적화 도구 - ONNX Runtime 같은 도구로 추론 성능 향상 가능
  • 하드웨어 최적화 - 다양한 하드웨어(CPU, GPU, NPU 등)에 맞게 최적화 가능

일반적인 사용 사례:

  • 모델 배포 - 클라우드, 엣지 디바이스, 모바일 등 다양한 환경에 배포
  • 프레임워크 간 모델 공유
  • 모델 최적화 및 양자화

파일 형식은 본질적으로 이진 형식이며 Protocol Buffers를 사용하여 직렬화됩니다.

 

 

 

ONNX 포맷사용방법

1. 모델 변환하기

PyTorch에서 ONNX로 변환:

 

import torch
import torchvision

# 모델 불러오기
model = torchvision.models.resnet50(pretrained=True)
model.eval()

# 더미 입력 생성 (입력 형태 지정용)
dummy_input = torch.randn(1, 3, 224, 224)

# ONNX로 내보내기
torch.onnx.export(model,               # 모델
                  dummy_input,         # 모델 입력의 예시
                  "resnet50.onnx",     # 저장될 파일 이름
                  export_params=True,  # 모델 파라미터 저장
                  opset_version=11,    # ONNX 버전
                  do_constant_folding=True)  # 상수 폴딩 최적화

 

 

TensorFlow에서 ONNX로 변환:

import tf2onnx
import tensorflow as tf

# TensorFlow 모델 불러오기
model = tf.keras.applications.MobileNetV2()

# ONNX로 변환
model_proto, _ = tf2onnx.convert.from_keras(model)

# 파일로 저장
with open("mobilenet.onnx", "wb") as f:
    f.write(model_proto.SerializeToString())

 

 

 

ONNX Runtime으로 추론하기

import numpy as np
import onnxruntime

# 추론 세션 생성
session = onnxruntime.InferenceSession("model.onnx")

# 입력 이름 확인
input_name = session.get_inputs()[0].name

# 임의 입력 데이터 생성
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# 추론 실행
result = session.run(None, {input_name: input_data})

 

 

 

 

TensorFlow에서 ONNX 모델 사용:

import onnx
import onnx_tf
import tensorflow as tf

# ONNX 모델 불러오기
onnx_model = onnx.load("model.onnx")

# TensorFlow 모델로 변환
tf_model = onnx_tf.backend.prepare(onnx_model)

# 추론 실행
output = tf_model.run(input_data)

 

 

 

 

 

 

_

반응형