triton Inference server에 모델들을 경량화 해서 올리다가
text recognition 모델을 onnx로 변환하는 과정에서 겪은 어려움이 있었어서 기록해두고자 한다.
Deep Text Recognition Benchmark 모델 구조
Deep Text Recognition Benchmark는 유연하게 모델을 구성할 수 있는 OCR(Open Character Recognition) 프레임워크로, 여러 단계에서 다양한 옵션을 조합하여 최적의 성능을 내는 구조를 설계할 수 있다.
1️⃣ Transformation (TPS / None)
Transformation 단계는 입력 이미지를 모델이 학습하기 좋은 형태로 변환하는 역할
- TPS (Thin Plate Spline Transformation)
- 왜곡된 이미지(예: 곡선 텍스트)를 직선화하여 모델이 더 쉽게 읽을 수 있도록 도움을 준다.
- 이미지에서 특정 지점을 기준으로 변형하여 텍스트를 정렬하는 방식.
- 실제 장면에서 찍힌 이미지나 텍스트의 각도가 불규칙할 때 유용하다.
- None
- 변환 과정을 생략합니다. 이미 텍스트가 잘 정렬되어 있는 상황이라면 이 옵션이 더 빠르기때문에 효율을 따져서 선택이 필요함.
2️⃣ Feature Extraction (ResNet / VGG / RCNN)
Feature Extraction 단계는 입력 이미지에서 중요한 시각적 특징(패턴, 모양 등)을 추출하는 부분.
- ResNet (Residual Network)
- 텍스트와 배경을 구분하는 데 강력한 성능을 보이며, 일반적으로 성능과 속도 모두에서 균형이 좋다고 합니다.
- VGG (Visual Geometry Group Network)
- 층이 깊고 간단한 구조로 고해상도 이미지를 처리하는 데 적합.
- 상대적으로 계산량이 많아 속도는 느릴 수 있지만, 작은 이미지에서 강력한 성능을 보여준다고 한다.
- RCNN (Recurrent Convolutional Neural Network)
- 순환 신경망(RNN)과 CNN을 결합한 구조로, 연속된 텍스트나 특징을 학습하는 데 적합.
3️⃣ Sequence Modeling (BiLSTM / None)
Sequence Modeling 단계는 이미지에서 추출된 특징을 시퀀스(순서 정보)로 변환하여 텍스트의 순서를 학습합니다.
- BiLSTM (Bidirectional Long Short-Term Memory)
- 입력 데이터를 양방향(앞뒤)으로 처리하여 텍스트의 문맥 정보를 학습합니다.
- 텍스트의 앞뒤 관계를 모두 고려할 수 있어 긴 문장에서 효과적입니다.
- None
- 위 단계 생략. 속도를 위해 효율적인 선택이 필요함.
4️⃣ Prediction (CTC / Attention)
Prediction 단계는 최종적으로 텍스트를 예측하는 역할.
- CTC (Connectionist Temporal Classification)
- 텍스트의 정렬 정보(문자 위치)가 없는 경우에 적합.
- 각 문자에 대해 독립적으로 예측하며, 동일한 문자의 중복을 제거하고 공백(Blank)을 처리해 최종 결과를 생성.
- 주로 이미지의 텍스트가 고정된 순서로 나올 때 사용됩니다.
- Attention
- 입력 이미지의 각 부분에 가중치를 부여해 중요한 정보를 선택적으로 집중합니다.
- 복잡한 배경이나 불규칙한 텍스트에서도 강력한 성능을 발휘함.
- CTC보다 더 정교한 예측이 가능하지만, 계산량이 많아질 수 있다.
맨 처음 text recognition 모델로 학습을 시켰던 모델 구조는
https://github.com/clovaai/deep-text-recognition-benchmark
GitHub - clovaai/deep-text-recognition-benchmark: Text recognition (optical character recognition) with deep learning methods, I
Text recognition (optical character recognition) with deep learning methods, ICCV 2019 - clovaai/deep-text-recognition-benchmark
github.com
깃허브를 바탕으로 TRBA (TPS-ResNet-BiLSTM-Attn) 로 학습을 시켰다. 결과는 (pre-trained 모델을 사용해서 fine-tunning을 시켰을 때) 프로젝트에 사용할 만큼의 충분한 정확도와 성능을 보였다. 하지만 엣지 컴퓨터에서의 원활한 배포와 관리를 위해 triton server에서 경량화를 진행하려 했고, onnx 변환을 하는 과정에서 문제가 생겼다.
우선 TPS 모듈을 ONNX로 변환하는 과정을 실패하고 찾아본 결과
https://github.com/clovaai/deep-text-recognition-benchmark/issues/191
Convert ONNX: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static
I using torch 1.3.1, torchvision 0.4.0 'Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible' at torch.onnx.export()
github.com
TPS는 동적 연산과 맞춤형 구현으로 ONNX 변환이 어렵다는 결론을 내렸다.
(GPT에게 물어보고 "TPS는 동적 연산(Grid Sampling, Dynamic Control Points)과 ONNX 미지원 연산(torch.grid_sample, torch.inverse) 때문에 변환이 어렵습니다." 라는 답변을 얻음)
그래서 필요시 TPS 연산 부분은 외부에서 진행하고 우선 (None- ResNet-BiLSTM-CTC) 로 재학습을 진행했다.
ONNX 변환 과정
ONNX 변환을 위한 모델 구조를 업데이트
class Model(nn.Module):
def __init__(self, opt):
super(Model, self).__init__()
self.opt = opt
self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}
""" Transformation """
if opt.Transformation == 'TPS':
self.Transformation = TPS_SpatialTransformerNetwork(
F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
else:
print('No Transformation module specified')
""" FeatureExtraction """
if opt.FeatureExtraction == 'VGG':
self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
elif opt.FeatureExtraction == 'RCNN':
self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
elif opt.FeatureExtraction == 'ResNet':
self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
else:
raise Exception('No FeatureExtraction module specified')
self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512
self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1
""" Sequence modeling"""
if opt.SequenceModeling == 'BiLSTM':
self.SequenceModeling = nn.Sequential(
BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
self.SequenceModeling_output = opt.hidden_size
else:
print('No SequenceModeling module specified')
self.SequenceModeling_output = self.FeatureExtraction_output
""" Prediction """
if opt.Prediction == 'CTC':
self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
elif opt.Prediction == 'Attn':
self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
else:
raise Exception('Prediction is neither CTC or Attn')
def forward(self, input, is_train=True):
""" Feature extraction stage """
visual_feature = self.FeatureExtraction(input)
visual_feature = visual_feature.permute(0, 3, 1, 2)
visual_feature = visual_feature.squeeze(3)
""" Prediction stage """
prediction = self.Prediction(visual_feature.contiguous())
return prediction
# 학습 옵션들 써주기
class InputData(BaseModel):
image_folder: str
saved_model: str
batch_max_length: int = 25
imgH: int = 32
imgW: int = 100
character: str = '0123456789.bcdefghijklmnopqrstuvwxyz'
Transformation: str = None
FeatureExtraction: str = 'ResNet'
SequenceModeling: str = 'BiLSTM'
Prediction:str = 'CTC'
num_fiducal: int = 20
input_channel: int = 1
output_channel: int = 512
hidden_size: int = 256
num_class:int = 37
data = {'FeatureExtraction' : 'ResNet', 'SequenceModeling' : 'BiLSTM', 'image_folder': 'demo_image/', 'saved_model' : 'pth모델 path'}
opt = InputData(**data)
model = Model(opt)
ONNX 변환
device = 'cuda:0'
model.load_state_dict(fix_model_state_dict(torch.load(opt.saved_model, map_location=device)))
input = torch.randn(1, 1, 32, 100)
model.eval()
torch.onnx.export(model,
input,
"None-ResNet-None-CTC.onnx",
export_params=True,
opset_version=11,
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ['input'],
output_names = ['output'],
dynamic_axes={'input' : {0 : 'batch_size',},
'output' : {1 : 'seq_length'}})
print("Model converted succesfully")
* 혹시 ONNX 변환 시 모델의 상태 딕셔너리에서 "module." 접두어가 붙어서 keyError가 나는 경우 아래 함수를 사용하시면 됩니다.
def fix_model_state_dict(state_dict):
new_state_dict = OrderedDict()
for key, value in state_dict.items():
if key.startswith('module.'):
key = key[7:]
new_state_dict[key] = value
return new_state_dict
결과적으로 위에 구조는 onnx 변환을 성공적으로 하지 못했다.
onnx 변환 시 아래와 같은 경고문이 떴는데 LSTM을 ONNX 변환할 때는 배치 사이즈를 1로만 고정하거나 모델 input값을 초기화하는 과정을 정의해줘야 한다는 경고였다.
UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model.
ONNX 모델 추론해보기
import onnxruntime as ort
import numpy as np
import cv2
import torch
# 이미지 전처리
image_path = 'image1.png'
orig_image = cv2.imread(image_path)
image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2GRAY)
image = image/127.5 - 1.0
image = cv2.resize(image,(100, 32),interpolation=cv2.INTER_CUBIC)
image = np.expand_dims(image, 0)
image = np.expand_dims(image, 0)
image = np.float32(image)
torch_input = torch.from_numpy(image)
ort_session = ort.InferenceSession("모델.onnx")
def to_numpy(tensor):
print(tensor)
return tensor.detach().cpu().numpy()
# # compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: image}
ort_outs = ort_session.run(None, ort_inputs)
outputs = ort_outs[0]
character_map = " 0123456789.bcdefghijklmnopqrstuvwxy"
def ctc_decode(preds, character_map):
pred_index = np.argmax(preds, axis=2) # [batch_size, seq_length]
BLANK = 0
texts = []
# 첫 번째 예시(batch_size=1)만 처리
output = pred_index[0, :]
characters = []
for i in range(output.shape[0]): # seq_length
if output[i] != BLANK and (i == 0 or output[i - 1] != output[i]): # Blank or duplicate char
characters.append(character_map[output[i]])
text = ''.join(characters)
return text
output_text = ctc_decode(ort_outs[0], character_map)
print(f"디코딩된 텍스트: {output_text}")
BiLSTM이 들어간 모델은 00680을 00080으로 추론을 하는 등 결과 성능이 나빠진 모습을 보였다.
위에 경고문을 참고해서 Sequence Modeling 구조를 변경하고 변경한 구조에 맞춰 lstm에 초기 상태를 함께 넣어주는 시도를 해보았지만 좋은 결과를 보지 못했다....
# 실패한 시도....
h0_1 = torch.zeros(2, batch_size, hidden_size) # 첫 번째 BiLSTM의 초기 상태
c0_1 = torch.zeros(2, batch_size, hidden_size)
h0_2 = torch.zeros(2, batch_size, hidden_size) # 두 번째 BiLSTM의 초기 상태
c0_2 = torch.zeros(2, batch_size, hidden_size)
torch.onnx.export(model,
(input, h0_1, c0_1, h0_2, c0_2),
"LSTM.onnx",
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=["input", "h0_1", "c0_1", "h0_2", "c0_2"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size", 1: "seq_length"},
"h0_1": {1: "batch_size"},
"c0_1": {1: "batch_size"},
"h0_2": {1: "batch_size"},
"c0_2": {1: "batch_size"},
"output": {0: "batch_size", 1: "seq_length"}})
결국 batch_size를 다이나믹하게 주기 위해서 우선 BiLSTM을 제외하고 None-ResNet-None-CTC 형식으로 재학습 후 onnx 변환을 한 결과 성능이 떨어지지 않고 onnx 변환에 성공했다!
프로젝트 진행과는 별개로 lstm을 onnx로 변환할 때 주의해야하는 부분을 좀 더 공부해서 lstm 구조까지 포함된 모델로 성공적인 변환을 해볼 예정...!!
'AI' 카테고리의 다른 글
[논문 리뷰] BPE Tokenizer (1) | 2024.02.04 |
---|---|
KLUE (0) | 2024.02.02 |
[논문 리뷰]attention 매커니즘 (0) | 2024.01.31 |
워드 임베딩 시각화 (0) | 2024.01.31 |
임베딩 (0) | 2024.01.24 |