refactor: simplify downloaded models

This commit is contained in:
Carson M
2023-01-17 10:09:15 -06:00
parent dbb1e00e9c
commit b49ca6f94d
10 changed files with 56 additions and 207 deletions

View File

@@ -1,83 +1,6 @@
#[cfg(feature = "fetch-models")]
use std::{
fs, io,
path::{Path, PathBuf},
time::Duration
};
#[cfg(feature = "fetch-models")]
use tracing::info;
#[cfg(feature = "fetch-models")]
use crate::error::{OrtDownloadError, OrtResult};
pub mod language;
pub mod vision;
/// Available pre-trained models to download from the [ONNX Model Zoo](https://github.com/onnx/models).
///
/// Values are usually obtained through the `From` conversions implemented for the individual
/// model-family enums rather than constructed directly.
#[derive(Debug, Clone)]
pub enum OnnxModel {
/// Computer vision models
Vision(vision::Vision),
/// Language models
Language(language::Language)
}
/// A model with a known URL from which it can be fetched.
///
/// Implemented by every downloadable model enum so that a session can be built directly from the
/// model's remote location.
pub trait ModelUrl {
	/// Returns the URL this model can be downloaded from.
	fn fetch_url(&self) -> &'static str;
}
// Delegates URL lookup to whichever model category is wrapped.
impl ModelUrl for OnnxModel {
	fn fetch_url(&self) -> &'static str {
		match self {
			Self::Vision(m) => m.fetch_url(),
			Self::Language(m) => m.fetch_url()
		}
	}
}
impl OnnxModel {
	/// Downloads this model from the ONNX Model Zoo into `download_dir`, returning the path to
	/// the downloaded file.
	///
	/// If a file with the model's filename already exists in `download_dir`, the download is
	/// skipped and the existing path is returned as-is (no integrity check is performed).
	///
	/// # Errors
	/// Returns an error if the HTTP request fails, the destination file cannot be created, or the
	/// number of bytes written does not match the server-reported `Content-Length`.
	///
	/// # Panics
	/// Panics if the server response lacks a parseable `Content-Length` header.
	#[cfg(feature = "fetch-models")]
	#[tracing::instrument]
	pub(crate) fn download_to<P>(&self, download_dir: P) -> OrtResult<PathBuf>
	where
		P: AsRef<Path> + std::fmt::Debug
	{
		let url = self.fetch_url();
		// The last URL path segment becomes the local filename.
		let model_filename = PathBuf::from(url.split('/').last().unwrap());
		let model_filepath = download_dir.as_ref().join(model_filename);
		if model_filepath.exists() {
			info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
			Ok(model_filepath)
		} else {
			info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{:?}", url).as_str(), "Downloading model");

			let resp = ureq::get(url)
				.timeout(Duration::from_secs(180))
				.call()
				.map_err(Box::new)
				.map_err(OrtDownloadError::FetchError)?;

			// `Content-Length` is required so the copy below can be verified; missing or
			// non-numeric values are a bug in the model zoo server and panic here.
			let len = resp
				.header("Content-Length")
				.and_then(|s| s.parse::<usize>().ok())
				.expect("Content-Length header should be present and numeric");
			info!(len, "Downloading {} bytes", len);

			let mut reader = resp.into_reader();
			// Fix: propagate file-creation failures as `IoError` instead of panicking via `unwrap()`.
			let f = fs::File::create(&model_filepath).map_err(OrtDownloadError::IoError)?;
			let mut writer = io::BufWriter::new(f);

			let bytes_io_count = io::copy(&mut reader, &mut writer).map_err(OrtDownloadError::IoError)?;
			if bytes_io_count == len as u64 {
				Ok(model_filepath)
			} else {
				Err(OrtDownloadError::CopyError {
					expected: len as u64,
					io: bytes_io_count
				}
				.into())
			}
		}
	}
}

View File

@@ -1,18 +1,3 @@
use super::ModelUrl;
pub mod machine_comprehension;
pub use machine_comprehension::MachineComprehension;
/// Available language model families from the ONNX Model Zoo.
#[derive(Debug, Clone)]
pub enum Language {
/// Machine comprehension models (includes RoBERTa and GPT-2 variants).
MachineComprehension(MachineComprehension)
}
// URL lookup simply forwards to the wrapped machine-comprehension model.
impl ModelUrl for Language {
	fn fetch_url(&self) -> &'static str {
		// `Language` currently has a single variant, so this pattern is irrefutable.
		let Language::MachineComprehension(model) = self;
		model.fetch_url()
	}
}

View File

@@ -1,6 +1,6 @@
#![allow(clippy::upper_case_acronyms)]
use crate::download::{language::Language, ModelUrl, OnnxModel};
use crate::download::ModelUrl;
/// Machine comprehension models.
///
@@ -61,21 +61,3 @@ impl ModelUrl for GPT2 {
}
}
}
// Conversions letting each machine-comprehension enum be passed anywhere an `OnnxModel` is
// accepted.
impl From<MachineComprehension> for OnnxModel {
	fn from(m: MachineComprehension) -> Self {
		Self::Language(Language::MachineComprehension(m))
	}
}

impl From<RoBERTa> for OnnxModel {
	fn from(v: RoBERTa) -> Self {
		Self::Language(Language::MachineComprehension(MachineComprehension::RoBERTa(v)))
	}
}

impl From<GPT2> for OnnxModel {
	fn from(v: GPT2) -> Self {
		Self::Language(Language::MachineComprehension(MachineComprehension::GPT2(v)))
	}
}

View File

@@ -1,5 +1,3 @@
use super::ModelUrl;
pub mod body_face_gesture_analysis;
pub mod domain_based_image_classification;
pub mod image_classification;
@@ -11,24 +9,3 @@ pub use domain_based_image_classification::DomainBasedImageClassification;
pub use image_classification::ImageClassification;
pub use image_manipulation::ImageManipulation;
pub use object_detection_image_segmentation::ObjectDetectionImageSegmentation;
/// Available computer vision model families from the ONNX Model Zoo.
#[derive(Debug, Clone)]
pub enum Vision {
/// Body, face & gesture analysis models
BodyFaceGestureAnalysis(BodyFaceGestureAnalysis),
/// Domain-based image classification models
DomainBasedImageClassification(DomainBasedImageClassification),
/// Image classification models
ImageClassification(ImageClassification),
/// Image manipulation models
ImageManipulation(ImageManipulation),
/// Object detection & image segmentation models
ObjectDetectionImageSegmentation(ObjectDetectionImageSegmentation)
}
// Delegates URL lookup to the wrapped model family; arms follow the enum's declaration order.
impl ModelUrl for Vision {
	fn fetch_url(&self) -> &'static str {
		match self {
			Self::BodyFaceGestureAnalysis(model) => model.fetch_url(),
			Self::DomainBasedImageClassification(model) => model.fetch_url(),
			Self::ImageClassification(model) => model.fetch_url(),
			Self::ImageManipulation(model) => model.fetch_url(),
			Self::ObjectDetectionImageSegmentation(model) => model.fetch_url()
		}
	}
}

View File

@@ -1,4 +1,4 @@
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
use crate::download::ModelUrl;
#[derive(Debug, Clone)]
pub enum BodyFaceGestureAnalysis {
@@ -19,9 +19,3 @@ impl ModelUrl for BodyFaceGestureAnalysis {
}
}
}
// Allows a `BodyFaceGestureAnalysis` model to be used wherever an `OnnxModel` is expected.
impl From<BodyFaceGestureAnalysis> for OnnxModel {
	fn from(m: BodyFaceGestureAnalysis) -> Self {
		Self::Vision(Vision::BodyFaceGestureAnalysis(m))
	}
}

View File

@@ -1,4 +1,4 @@
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
use crate::download::ModelUrl;
#[derive(Debug, Clone)]
pub enum DomainBasedImageClassification {
@@ -13,9 +13,3 @@ impl ModelUrl for DomainBasedImageClassification {
}
}
}
// Allows a `DomainBasedImageClassification` model to be used wherever an `OnnxModel` is expected.
impl From<DomainBasedImageClassification> for OnnxModel {
	fn from(m: DomainBasedImageClassification) -> Self {
		Self::Vision(Vision::DomainBasedImageClassification(m))
	}
}

View File

@@ -1,6 +1,6 @@
#![allow(clippy::upper_case_acronyms)]
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
use crate::download::ModelUrl;
#[derive(Debug, Clone)]
pub enum ImageClassification {
@@ -204,33 +204,3 @@ impl ModelUrl for ShuffleNetVersion {
}
}
}
// Conversions letting each image-classification enum be used directly as an `OnnxModel`.
impl From<ImageClassification> for OnnxModel {
	fn from(m: ImageClassification) -> Self {
		Self::Vision(Vision::ImageClassification(m))
	}
}

impl From<ResNet> for OnnxModel {
	fn from(v: ResNet) -> Self {
		Self::Vision(Vision::ImageClassification(ImageClassification::ResNet(v)))
	}
}

impl From<Vgg> for OnnxModel {
	fn from(v: Vgg) -> Self {
		Self::Vision(Vision::ImageClassification(ImageClassification::Vgg(v)))
	}
}

impl From<InceptionVersion> for OnnxModel {
	fn from(v: InceptionVersion) -> Self {
		Self::Vision(Vision::ImageClassification(ImageClassification::Inception(v)))
	}
}

impl From<ShuffleNetVersion> for OnnxModel {
	fn from(v: ShuffleNetVersion) -> Self {
		Self::Vision(Vision::ImageClassification(ImageClassification::ShuffleNet(v)))
	}
}

View File

@@ -1,4 +1,4 @@
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
use crate::download::ModelUrl;
/// Image Manipulation
///
@@ -54,15 +54,3 @@ impl ModelUrl for FastNeuralStyleTransferStyle {
}
}
}
// Conversions letting the image-manipulation enums be used directly as an `OnnxModel`.
impl From<ImageManipulation> for OnnxModel {
	fn from(m: ImageManipulation) -> Self {
		Self::Vision(Vision::ImageManipulation(m))
	}
}

impl From<FastNeuralStyleTransferStyle> for OnnxModel {
	fn from(s: FastNeuralStyleTransferStyle) -> Self {
		Self::Vision(Vision::ImageManipulation(ImageManipulation::FastNeuralStyleTransfer(s)))
	}
}

View File

@@ -1,6 +1,6 @@
#![allow(clippy::upper_case_acronyms)]
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
use crate::download::ModelUrl;
/// Object Detection & Image Segmentation
///
@@ -88,9 +88,3 @@ impl ModelUrl for ObjectDetectionImageSegmentation {
}
}
}
// Allows an `ObjectDetectionImageSegmentation` model to be used wherever an `OnnxModel` is
// expected.
impl From<ObjectDetectionImageSegmentation> for OnnxModel {
	fn from(m: ObjectDetectionImageSegmentation) -> Self {
		Self::Vision(Vision::ObjectDetectionImageSegmentation(m))
	}
}

View File

@@ -1,11 +1,11 @@
#![allow(clippy::tabs_in_doc_comments)]
#[cfg(feature = "fetch-models")]
use std::env;
#[cfg(not(target_family = "windows"))]
use std::os::unix::ffi::OsStrExt;
#[cfg(target_family = "windows")]
use std::os::windows::ffi::OsStrExt;
#[cfg(feature = "fetch-models")]
use std::{env, path::PathBuf, time::Duration};
use std::{
ffi::CString,
fmt::{self, Debug},
@@ -33,7 +33,7 @@ use super::{
AllocatorType, GraphOptimizationLevel, MemType
};
#[cfg(feature = "fetch-models")]
use super::{download::OnnxModel, error::OrtDownloadError};
use super::{download::ModelUrl, error::OrtDownloadError};
/// Type used to create a session using the _builder pattern_.
///
@@ -266,18 +266,60 @@ impl SessionBuilder {
/// Downloads a pre-trained ONNX model from the URL reported by `model`'s [`ModelUrl`]
/// implementation and builds a session from it.
///
/// The model file is downloaded into the process's current working directory; if it already
/// exists there, the cached copy is used instead.
///
/// # Errors
/// Returns an error if the download fails or the downloaded file cannot be loaded as a session.
#[cfg(feature = "fetch-models")]
pub fn with_model_downloaded<M>(self, model: M) -> OrtResult<Session>
where
	M: ModelUrl
{
	self.with_model_downloaded_monomorphized(model.fetch_url())
}
/// Monomorphized helper for [`SessionBuilder::with_model_downloaded`]: downloads the model at
/// URL `model` into the current working directory, then builds a session from the file.
#[cfg(feature = "fetch-models")]
fn with_model_downloaded_monomorphized(self, model: &str) -> OrtResult<Session> {
	let download_dir = env::current_dir().map_err(OrtDownloadError::IoError)?;
	let downloaded_path = self.download_to(model, download_dir)?;
	self.with_model_from_file(downloaded_path)
}
/// Downloads the file at `url` into `download_dir`, returning the path to the downloaded file.
///
/// The last path segment of `url` is used as the local filename; if that file already exists in
/// `download_dir`, the download is skipped (no integrity check is performed).
///
/// # Errors
/// Returns an error if the HTTP request fails, the destination file cannot be created, or the
/// number of bytes written does not match the server-reported `Content-Length`.
///
/// # Panics
/// Panics if the server response lacks a parseable `Content-Length` header.
#[cfg(feature = "fetch-models")]
#[tracing::instrument]
fn download_to<P>(&self, url: &str, download_dir: P) -> OrtResult<PathBuf>
where
	P: AsRef<Path> + std::fmt::Debug
{
	let model_filename = PathBuf::from(url.split('/').last().unwrap());
	let model_filepath = download_dir.as_ref().join(model_filename);
	if model_filepath.exists() {
		tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
		Ok(model_filepath)
	} else {
		tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{:?}", url).as_str(), "Downloading model");

		let resp = ureq::get(url)
			.timeout(Duration::from_secs(180))
			.call()
			.map_err(Box::new)
			.map_err(OrtDownloadError::FetchError)?;

		// `Content-Length` is required so the copy below can be verified; a missing or
		// non-numeric value is a server-side bug and panics here.
		let len = resp
			.header("Content-Length")
			.and_then(|s| s.parse::<usize>().ok())
			.expect("Content-Length header should be present and numeric");
		tracing::info!(len, "Downloading {} bytes", len);

		let mut reader = resp.into_reader();
		// Fix: propagate file-creation failures as `IoError` instead of panicking via `unwrap()`.
		let f = std::fs::File::create(&model_filepath).map_err(OrtDownloadError::IoError)?;
		let mut writer = std::io::BufWriter::new(f);

		let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(OrtDownloadError::IoError)?;
		if bytes_io_count == len as u64 {
			Ok(model_filepath)
		} else {
			Err(OrtDownloadError::CopyError {
				expected: len as u64,
				io: bytes_io_count
			}
			.into())
		}
	}
}
// TODO: Add all functions changing the options.
// See all OrtApi methods taking a `options: *mut OrtSessionOptions`.