Backend/FastAPI

[FastAPI] 딥러닝 모델 서빙하기

mopil 2024. 1. 3. 17:11
반응형

pytorch아 tensorflow로 작성된 딥러닝 모델을 FastAPI로 서빙하는 방법에대해 공유한다.

 

# 서론

딥러닝 개발자가 어느정도 서버 지식을 갖고 있어서 모델 서빙까지 해주면 더할나위 없이 좋지만, 그렇지 않은 경우 서버 개발자가 이를 수행해야한다.

 

필자는 보통 모델 서빙 서버와 비즈니스 서버를 분리하지 않고 하나의 FastAPI 프레임워크로 구성하는데, 프로젝트 규모가 크다면 모델 서빙용 서버를 따로 분리하는 것도 방법이다.

 

 

# 딥러닝 개발자의 역할

필자는 Vision 관련 모델 (Shufflenet이나 resnet 등)만 연동해봐서, 이 기준으로 설명함을 미리 알린다.

 

1. 모델 학습시키기

딥러닝 개발자는 코랩이든 워크스테이션이든 일단 학습을 시킨다.

 

2. 모델 소스와 weight 파일 전달

학습이 완료된 모델 소스와 과 weight 파일 (.pt로 끝나는, 토치의 경우)을 서버 디렉토리에 위치시킨다. 서버 레포에 딥러닝 개발자가 직접 PR을 올려주면 좋다.

 

weight파일은 모델이 학습한 결과치를 담은 파일로, 이 파일을 읽어서 모델이 추론을 수행하기 때문에 모델 구동을 위해서 필수적인 파일이다.

 

(PR 예시)

https://github.com/project-sulsul/sulsul-backend/pull/10

 

딥러닝 파트 코드 추가 by Sangh0 · Pull Request #10 · project-sulsul/sulsul-backend

메인: inference, models/resnet

github.com

 

그리고 해당 모델을 사용할 수 있는 API (위 경우 inference 함수)를 제공해달라고 하면 된다.

 

def classify(
    img_url: str,
    weight_file_path: str = "ai/weights/resnet18_qat.pt",
    model_name: str = "resnet18",
    threshold: float = 0.5,
    quantization: str = "qat",
    num_classes: int = 39,
) -> ClassificationResultDto:
    # load model
    model = load_model(
        model_name=model_name,
        weight=weight_file_path,
        num_classes=num_classes,
        quantization=quantization,
    )

    # load image
    img, img_url = load_image(img_url=img_url)

    # inference
    result = inference(img, model, threshold)

    return ClassificationResultDto(foods=result["foods"], alcohols=result["alcohols"])

 

서버 개발자는 모델 관련 코드를 아예 몰라도 해당 classify 함수만 호출하면 된다.

 

 

모델 코드가 변경되거나 하면, 서버 레포에 PR을 올려달라는 식으로 요청해서 코드의 일관성을 유지하자.

(보통 코랩으로 작성하는 경우가 많아, 복사-붙혀넣기 하는 과정에서 코드가 손상되는 경우가 많은데, 꼭 딥러닝 개발자에게 더블체크를 부탁하자)

 

# 서버 개발자의 역할

이렇게 전달받은 모델 코드를 가지고 라우터 API를 만들면 된다.

 

https://github.com/project-sulsul/sulsul-backend/blob/main/api/routers/test_router.py

 

@router.post("/ai")
async def get_inference_from_image(image: UploadFile, model_name: AiModel, threshold: float = 0.5):
    url = upload_file_to_s3(image, "images")
    weight_file_path = f"ai/weights/{model_name.value}_qat.pt"
    return classify(url, weight_file_path=weight_file_path, model_name=model_name.value, threshold=threshold)

 

 

그리고 동일하게 배포를 해주면 끝

 

 

필자는 보통 ai 디렉토리를 따로 빼서, 여기서 모델 관련 코드를 모두 관리하게끔 구성한다.

 

반응형