mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: LoRA adapters
This commit is contained in:
52
src/adapter.rs
Normal file
52
src/adapter.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use std::{
|
||||
path::Path,
|
||||
ptr::{self, NonNull},
|
||||
sync::Arc
|
||||
};
|
||||
|
||||
use crate::{Allocator, Result, ortsys, util};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct AdapterInner {
|
||||
pub(crate) ptr: NonNull<ort_sys::OrtLoraAdapter>
|
||||
}
|
||||
|
||||
impl Drop for AdapterInner {
|
||||
fn drop(&mut self) {
|
||||
ortsys![unsafe ReleaseLoraAdapter(self.ptr.as_ptr())];
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Adapter {
|
||||
pub(crate) inner: Arc<AdapterInner>
|
||||
}
|
||||
|
||||
impl Adapter {
|
||||
pub fn from_file(path: impl AsRef<Path>, allocator: Option<&Allocator>) -> Result<Self> {
|
||||
let path = util::path_to_os_char(path);
|
||||
let allocator_ptr = allocator.map(|c| c.ptr()).unwrap_or_else(ptr::null_mut);
|
||||
let mut ptr = ptr::null_mut();
|
||||
ortsys![unsafe CreateLoraAdapter(path.as_ptr(), allocator_ptr, &mut ptr)?];
|
||||
Ok(Adapter {
|
||||
inner: Arc::new(AdapterInner {
|
||||
ptr: unsafe { NonNull::new_unchecked(ptr) }
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_memory(bytes: &[u8], allocator: Option<&Allocator>) -> Result<Self> {
|
||||
let allocator_ptr = allocator.map(|c| c.ptr()).unwrap_or_else(ptr::null_mut);
|
||||
let mut ptr = ptr::null_mut();
|
||||
ortsys![unsafe CreateLoraAdapterFromArray(bytes.as_ptr().cast(), bytes.len(), allocator_ptr, &mut ptr)?];
|
||||
Ok(Adapter {
|
||||
inner: Arc::new(AdapterInner {
|
||||
ptr: unsafe { NonNull::new_unchecked(ptr) }
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn ptr(&self) -> *mut ort_sys::OrtLoraAdapter {
|
||||
self.inner.ptr.as_ptr()
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@
|
||||
#[cfg(all(test, not(feature = "fetch-models")))]
|
||||
compile_error!("`cargo test --features fetch-models`!!1!");
|
||||
|
||||
pub(crate) mod adapter;
|
||||
pub(crate) mod environment;
|
||||
pub(crate) mod error;
|
||||
pub(crate) mod execution_providers;
|
||||
@@ -48,6 +49,7 @@ pub use self::tensor::ArrayExtensions;
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "training")))]
|
||||
pub use self::training::*;
|
||||
pub use self::{
|
||||
adapter::Adapter,
|
||||
environment::{Environment, EnvironmentBuilder, EnvironmentGlobalThreadPoolOptions, get_environment, init},
|
||||
error::{Error, ErrorCode, Result},
|
||||
execution_providers::*,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::{collections::HashMap, ffi::CString, marker::PhantomData, ptr::NonNull, sync::Arc};
|
||||
|
||||
use crate::{
|
||||
adapter::{Adapter, AdapterInner},
|
||||
error::Result,
|
||||
ortsys,
|
||||
session::Output,
|
||||
@@ -157,6 +158,7 @@ impl SelectedOutputMarker for HasSelectedOutputs {}
|
||||
pub struct RunOptions<O: SelectedOutputMarker = NoSelectedOutputs> {
|
||||
pub(crate) run_options_ptr: NonNull<ort_sys::OrtRunOptions>,
|
||||
pub(crate) outputs: OutputSelector,
|
||||
adapters: Vec<Arc<AdapterInner>>,
|
||||
_marker: PhantomData<O>
|
||||
}
|
||||
|
||||
@@ -175,6 +177,7 @@ impl RunOptions {
|
||||
Ok(RunOptions {
|
||||
run_options_ptr: unsafe { NonNull::new_unchecked(run_options_ptr) },
|
||||
outputs: OutputSelector::default(),
|
||||
adapters: Vec::new(),
|
||||
_marker: PhantomData
|
||||
})
|
||||
}
|
||||
@@ -303,6 +306,12 @@ impl<O: SelectedOutputMarker> RunOptions<O> {
|
||||
ortsys![unsafe AddRunConfigEntry(self.run_options_ptr.as_ptr(), key.as_ptr(), value.as_ptr())?];
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn add_adapter(&mut self, adapter: &Adapter) -> Result<()> {
|
||||
ortsys![unsafe RunOptionsAddActiveLoraAdapter(self.run_options_ptr.as_ptr(), adapter.ptr())?];
|
||||
self.adapters.push(Arc::clone(&adapter.inner));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<O: SelectedOutputMarker> Drop for RunOptions<O> {
|
||||
|
||||
Reference in New Issue
Block a user