mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
refactor: simplify downloaded models
This commit is contained in:
@@ -1,83 +1,6 @@
|
||||
#[cfg(feature = "fetch-models")]
|
||||
use std::{
|
||||
fs, io,
|
||||
path::{Path, PathBuf},
|
||||
time::Duration
|
||||
};
|
||||
|
||||
#[cfg(feature = "fetch-models")]
|
||||
use tracing::info;
|
||||
|
||||
#[cfg(feature = "fetch-models")]
|
||||
use crate::error::{OrtDownloadError, OrtResult};
|
||||
|
||||
pub mod language;
|
||||
pub mod vision;
|
||||
|
||||
/// Available pre-trained models to download from the [ONNX Model Zoo](https://github.com/onnx/models).
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum OnnxModel {
|
||||
/// Computer vision models
|
||||
Vision(vision::Vision),
|
||||
/// Language models
|
||||
Language(language::Language)
|
||||
}
|
||||
|
||||
trait ModelUrl {
|
||||
pub trait ModelUrl {
|
||||
fn fetch_url(&self) -> &'static str;
|
||||
}
|
||||
|
||||
impl ModelUrl for OnnxModel {
|
||||
fn fetch_url(&self) -> &'static str {
|
||||
match self {
|
||||
OnnxModel::Vision(model) => model.fetch_url(),
|
||||
OnnxModel::Language(model) => model.fetch_url()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OnnxModel {
|
||||
#[cfg(feature = "fetch-models")]
|
||||
#[tracing::instrument]
|
||||
pub(crate) fn download_to<P>(&self, download_dir: P) -> OrtResult<PathBuf>
|
||||
where
|
||||
P: AsRef<Path> + std::fmt::Debug
|
||||
{
|
||||
let url = self.fetch_url();
|
||||
|
||||
let model_filename = PathBuf::from(url.split('/').last().unwrap());
|
||||
let model_filepath = download_dir.as_ref().join(model_filename);
|
||||
if model_filepath.exists() {
|
||||
info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
|
||||
Ok(model_filepath)
|
||||
} else {
|
||||
info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{:?}", url).as_str(), "Downloading model");
|
||||
|
||||
let resp = ureq::get(url)
|
||||
.timeout(Duration::from_secs(180))
|
||||
.call()
|
||||
.map_err(Box::new)
|
||||
.map_err(OrtDownloadError::FetchError)?;
|
||||
|
||||
assert!(resp.has("Content-Length"));
|
||||
let len = resp.header("Content-Length").and_then(|s| s.parse::<usize>().ok()).unwrap();
|
||||
info!(len, "Downloading {} bytes", len);
|
||||
|
||||
let mut reader = resp.into_reader();
|
||||
|
||||
let f = fs::File::create(&model_filepath).unwrap();
|
||||
let mut writer = io::BufWriter::new(f);
|
||||
|
||||
let bytes_io_count = io::copy(&mut reader, &mut writer).map_err(OrtDownloadError::IoError)?;
|
||||
if bytes_io_count == len as u64 {
|
||||
Ok(model_filepath)
|
||||
} else {
|
||||
Err(OrtDownloadError::CopyError {
|
||||
expected: len as u64,
|
||||
io: bytes_io_count
|
||||
}
|
||||
.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,3 @@
|
||||
use super::ModelUrl;
|
||||
|
||||
pub mod machine_comprehension;
|
||||
|
||||
pub use machine_comprehension::MachineComprehension;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Language {
|
||||
MachineComprehension(MachineComprehension)
|
||||
}
|
||||
|
||||
impl ModelUrl for Language {
|
||||
fn fetch_url(&self) -> &'static str {
|
||||
match self {
|
||||
Language::MachineComprehension(v) => v.fetch_url()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#![allow(clippy::upper_case_acronyms)]
|
||||
|
||||
use crate::download::{language::Language, ModelUrl, OnnxModel};
|
||||
use crate::download::ModelUrl;
|
||||
|
||||
/// Machine comprehension models.
|
||||
///
|
||||
@@ -61,21 +61,3 @@ impl ModelUrl for GPT2 {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MachineComprehension> for OnnxModel {
|
||||
fn from(model: MachineComprehension) -> Self {
|
||||
OnnxModel::Language(Language::MachineComprehension(model))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RoBERTa> for OnnxModel {
|
||||
fn from(model: RoBERTa) -> Self {
|
||||
OnnxModel::Language(Language::MachineComprehension(MachineComprehension::RoBERTa(model)))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<GPT2> for OnnxModel {
|
||||
fn from(model: GPT2) -> Self {
|
||||
OnnxModel::Language(Language::MachineComprehension(MachineComprehension::GPT2(model)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
use super::ModelUrl;
|
||||
|
||||
pub mod body_face_gesture_analysis;
|
||||
pub mod domain_based_image_classification;
|
||||
pub mod image_classification;
|
||||
@@ -11,24 +9,3 @@ pub use domain_based_image_classification::DomainBasedImageClassification;
|
||||
pub use image_classification::ImageClassification;
|
||||
pub use image_manipulation::ImageManipulation;
|
||||
pub use object_detection_image_segmentation::ObjectDetectionImageSegmentation;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Vision {
|
||||
BodyFaceGestureAnalysis(BodyFaceGestureAnalysis),
|
||||
DomainBasedImageClassification(DomainBasedImageClassification),
|
||||
ImageClassification(ImageClassification),
|
||||
ImageManipulation(ImageManipulation),
|
||||
ObjectDetectionImageSegmentation(ObjectDetectionImageSegmentation)
|
||||
}
|
||||
|
||||
impl ModelUrl for Vision {
|
||||
fn fetch_url(&self) -> &'static str {
|
||||
match self {
|
||||
Vision::DomainBasedImageClassification(v) => v.fetch_url(),
|
||||
Vision::ImageClassification(v) => v.fetch_url(),
|
||||
Vision::ImageManipulation(v) => v.fetch_url(),
|
||||
Vision::ObjectDetectionImageSegmentation(v) => v.fetch_url(),
|
||||
Vision::BodyFaceGestureAnalysis(v) => v.fetch_url()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
|
||||
use crate::download::ModelUrl;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum BodyFaceGestureAnalysis {
|
||||
@@ -19,9 +19,3 @@ impl ModelUrl for BodyFaceGestureAnalysis {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<BodyFaceGestureAnalysis> for OnnxModel {
|
||||
fn from(model: BodyFaceGestureAnalysis) -> Self {
|
||||
OnnxModel::Vision(Vision::BodyFaceGestureAnalysis(model))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
|
||||
use crate::download::ModelUrl;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DomainBasedImageClassification {
|
||||
@@ -13,9 +13,3 @@ impl ModelUrl for DomainBasedImageClassification {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DomainBasedImageClassification> for OnnxModel {
|
||||
fn from(model: DomainBasedImageClassification) -> Self {
|
||||
OnnxModel::Vision(Vision::DomainBasedImageClassification(model))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#![allow(clippy::upper_case_acronyms)]
|
||||
|
||||
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
|
||||
use crate::download::ModelUrl;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ImageClassification {
|
||||
@@ -204,33 +204,3 @@ impl ModelUrl for ShuffleNetVersion {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ImageClassification> for OnnxModel {
|
||||
fn from(model: ImageClassification) -> Self {
|
||||
OnnxModel::Vision(Vision::ImageClassification(model))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ResNet> for OnnxModel {
|
||||
fn from(variant: ResNet) -> Self {
|
||||
OnnxModel::Vision(Vision::ImageClassification(ImageClassification::ResNet(variant)))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vgg> for OnnxModel {
|
||||
fn from(variant: Vgg) -> Self {
|
||||
OnnxModel::Vision(Vision::ImageClassification(ImageClassification::Vgg(variant)))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InceptionVersion> for OnnxModel {
|
||||
fn from(variant: InceptionVersion) -> Self {
|
||||
OnnxModel::Vision(Vision::ImageClassification(ImageClassification::Inception(variant)))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ShuffleNetVersion> for OnnxModel {
|
||||
fn from(variant: ShuffleNetVersion) -> Self {
|
||||
OnnxModel::Vision(Vision::ImageClassification(ImageClassification::ShuffleNet(variant)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
|
||||
use crate::download::ModelUrl;
|
||||
|
||||
/// Image Manipulation
|
||||
///
|
||||
@@ -54,15 +54,3 @@ impl ModelUrl for FastNeuralStyleTransferStyle {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ImageManipulation> for OnnxModel {
|
||||
fn from(model: ImageManipulation) -> Self {
|
||||
OnnxModel::Vision(Vision::ImageManipulation(model))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<FastNeuralStyleTransferStyle> for OnnxModel {
|
||||
fn from(style: FastNeuralStyleTransferStyle) -> Self {
|
||||
OnnxModel::Vision(Vision::ImageManipulation(ImageManipulation::FastNeuralStyleTransfer(style)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#![allow(clippy::upper_case_acronyms)]
|
||||
|
||||
use crate::download::{vision::Vision, ModelUrl, OnnxModel};
|
||||
use crate::download::ModelUrl;
|
||||
|
||||
/// Object Detection & Image Segmentation
|
||||
///
|
||||
@@ -88,9 +88,3 @@ impl ModelUrl for ObjectDetectionImageSegmentation {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ObjectDetectionImageSegmentation> for OnnxModel {
|
||||
fn from(model: ObjectDetectionImageSegmentation) -> Self {
|
||||
OnnxModel::Vision(Vision::ObjectDetectionImageSegmentation(model))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
#![allow(clippy::tabs_in_doc_comments)]
|
||||
|
||||
#[cfg(feature = "fetch-models")]
|
||||
use std::env;
|
||||
#[cfg(not(target_family = "windows"))]
|
||||
use std::os::unix::ffi::OsStrExt;
|
||||
#[cfg(target_family = "windows")]
|
||||
use std::os::windows::ffi::OsStrExt;
|
||||
#[cfg(feature = "fetch-models")]
|
||||
use std::{env, path::PathBuf, time::Duration};
|
||||
use std::{
|
||||
ffi::CString,
|
||||
fmt::{self, Debug},
|
||||
@@ -33,7 +33,7 @@ use super::{
|
||||
AllocatorType, GraphOptimizationLevel, MemType
|
||||
};
|
||||
#[cfg(feature = "fetch-models")]
|
||||
use super::{download::OnnxModel, error::OrtDownloadError};
|
||||
use super::{download::ModelUrl, error::OrtDownloadError};
|
||||
|
||||
/// Type used to create a session using the _builder pattern_.
|
||||
///
|
||||
@@ -266,18 +266,60 @@ impl SessionBuilder {
|
||||
#[cfg(feature = "fetch-models")]
|
||||
pub fn with_model_downloaded<M>(self, model: M) -> OrtResult<Session>
|
||||
where
|
||||
M: Into<OnnxModel>
|
||||
M: ModelUrl
|
||||
{
|
||||
self.with_model_downloaded_monomorphized(model.into())
|
||||
self.with_model_downloaded_monomorphized(model.fetch_url())
|
||||
}
|
||||
|
||||
#[cfg(feature = "fetch-models")]
|
||||
fn with_model_downloaded_monomorphized(self, model: OnnxModel) -> OrtResult<Session> {
|
||||
fn with_model_downloaded_monomorphized(self, model: &str) -> OrtResult<Session> {
|
||||
let download_dir = env::current_dir().map_err(OrtDownloadError::IoError)?;
|
||||
let downloaded_path = model.download_to(download_dir)?;
|
||||
let downloaded_path = self.download_to(model, download_dir)?;
|
||||
self.with_model_from_file(downloaded_path)
|
||||
}
|
||||
|
||||
#[cfg(feature = "fetch-models")]
|
||||
#[tracing::instrument]
|
||||
fn download_to<P>(&self, url: &str, download_dir: P) -> OrtResult<PathBuf>
|
||||
where
|
||||
P: AsRef<Path> + std::fmt::Debug
|
||||
{
|
||||
let model_filename = PathBuf::from(url.split('/').last().unwrap());
|
||||
let model_filepath = download_dir.as_ref().join(model_filename);
|
||||
if model_filepath.exists() {
|
||||
tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
|
||||
Ok(model_filepath)
|
||||
} else {
|
||||
tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{:?}", url).as_str(), "Downloading model");
|
||||
|
||||
let resp = ureq::get(url)
|
||||
.timeout(Duration::from_secs(180))
|
||||
.call()
|
||||
.map_err(Box::new)
|
||||
.map_err(OrtDownloadError::FetchError)?;
|
||||
|
||||
assert!(resp.has("Content-Length"));
|
||||
let len = resp.header("Content-Length").and_then(|s| s.parse::<usize>().ok()).unwrap();
|
||||
tracing::info!(len, "Downloading {} bytes", len);
|
||||
|
||||
let mut reader = resp.into_reader();
|
||||
|
||||
let f = std::fs::File::create(&model_filepath).unwrap();
|
||||
let mut writer = std::io::BufWriter::new(f);
|
||||
|
||||
let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(OrtDownloadError::IoError)?;
|
||||
if bytes_io_count == len as u64 {
|
||||
Ok(model_filepath)
|
||||
} else {
|
||||
Err(OrtDownloadError::CopyError {
|
||||
expected: len as u64,
|
||||
io: bytes_io_count
|
||||
}
|
||||
.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add all functions changing the options.
|
||||
// See all OrtApi methods taking a `options: *mut OrtSessionOptions`.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user