mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: tests for adapter API
This commit is contained in:
@@ -50,3 +50,58 @@ impl Adapter {
|
||||
self.inner.ptr.as_ptr()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
	use std::fs;

	use super::Adapter;
	use crate::{RunOptions, Session, Tensor};

	// Expected outputs for a [4, 4] all-ones input:
	// base:   input @ weight_x (weight_x = 1..=16 row-major) -> column sums [28, 32, 36, 40]
	// lora:   (input @ A) @ B with A = [3, 4, 5, 6]^T, B = [7, 8, 9, 10] -> 18 * [7, 8, 9, 10]
	// total:  [154, 176, 198, 220] per row.

	/// Loads a LoRA adapter from a file and verifies it is applied during inference.
	#[test]
	fn test_lora() -> crate::Result<()> {
		let model = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?;
		let lora = Adapter::from_file("tests/data/adapter.orl", None)?;

		let mut run_options = RunOptions::new()?;
		run_options.add_adapter(&lora)?;

		let output: Tensor<f32> = model
			.run_with_options(crate::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?]?, &run_options)?
			.remove("output")
			.expect("model should have an `output` tensor")
			.downcast()?;
		let (_, output) = output.extract_raw_tensor();
		assert_eq!(output[0], 154.0);
		assert_eq!(output[1], 176.0);
		assert_eq!(output[2], 198.0);
		assert_eq!(output[3], 220.0);

		Ok(())
	}

	/// Same as [`test_lora`], but loads the adapter from an in-memory buffer and
	/// drops the buffer before running to verify the adapter owns its own copy of the data.
	#[test]
	fn test_lora_from_memory() -> crate::Result<()> {
		let model = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?;

		let lora_bytes = fs::read("tests/data/adapter.orl").expect("failed to read `tests/data/adapter.orl`");
		let lora = Adapter::from_memory(&lora_bytes, None)?;
		// Intentionally freed before inference; `Adapter` must not borrow from `lora_bytes`.
		drop(lora_bytes);

		let mut run_options = RunOptions::new()?;
		run_options.add_adapter(&lora)?;

		let output: Tensor<f32> = model
			.run_with_options(crate::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?]?, &run_options)?
			.remove("output")
			.expect("model should have an `output` tensor")
			.downcast()?;
		let (_, output) = output.extract_raw_tensor();
		assert_eq!(output[0], 154.0);
		assert_eq!(output[1], 176.0);
		assert_eq!(output[2], 198.0);
		assert_eq!(output[3], 220.0);

		Ok(())
	}
}
|
||||
|
||||
BIN
tests/data/adapter.orl
Normal file
BIN
tests/data/adapter.orl
Normal file
Binary file not shown.
BIN
tests/data/lora_model.onnx
Normal file
BIN
tests/data/lora_model.onnx
Normal file
Binary file not shown.
84
tools/test-data/generate-models.py
Normal file
84
tools/test-data/generate-models.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import inspect
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnx.helper as G
|
||||
import onnxruntime as ort
|
||||
|
||||
def make_tensor_from_np(name: str, arr: np.ndarray) -> onnx.TensorProto:
|
||||
match arr.dtype:
|
||||
case np.float32:
|
||||
dtype = onnx.TensorProto.FLOAT
|
||||
shape = list(arr.shape)
|
||||
return G.make_tensor(name, dtype, shape, arr)
|
||||
|
||||
# Registry of artifact generators; populated by the decorators below and
# executed from the `__main__` block.
factories = []

def model_factory(func):
	"""Register `func` as an ONNX model generator.

	When invoked, the registered wrapper builds the model, wraps a bare
	`GraphProto` into a `ModelProto`, validates it with the ONNX checker,
	and saves it to `tests/data/<func name>.onnx`. Build failures are
	reported and skip the save; a failed validity check is reported but the
	model is still written.
	"""
	def wrapper():
		filename = f'{func.__name__}.onnx'
		try:
			model = func()
		except Exception as e:
			print(f'Failed to create `{filename}`: {e}')
			return
		if isinstance(model, onnx.GraphProto):
			model = G.make_model(model, opset_imports=[onnx.OperatorSetIdProto(domain=None, version=21)])
		try:
			onnx.checker.check_model(model)
		except Exception as e:
			# NOTE(review): the model is saved even when the checker rejects it —
			# presumably deliberate so the artifact can be inspected; confirm.
			print(f'`{filename}` is invalid: {e}')
		onnx.save_model(model, f'tests/data/{filename}')
	factories.append(wrapper)
	return wrapper
||||
def misc_factory(func):
	"""Register `func` as a generator of a non-model test artifact.

	Failures remain non-fatal so other artifacts are still generated, but
	they are now reported instead of being silently swallowed
	(`except ...: pass` left `e` unused and hid real errors), matching the
	error reporting style of `model_factory`.
	"""
	def wrapper():
		try:
			func()
		except Exception as e:
			print(f'Failed to run `{func.__name__}`: {e}')
	factories.append(wrapper)
	return wrapper
||||
|
||||
class Models:
	@model_factory
	def lora_model():
		"""Build `lora_model.onnx`: output = input @ weight_x + (input @ A) @ B.

		The LoRA parameters `A`/`B` are graph inputs with a dynamic inner
		dimension `dim` and zero-sized default initializers, so the LoRA path
		contributes nothing unless an adapter overrides them at run time.
		"""
		input_x = G.make_tensor_value_info('input', onnx.TensorProto.FLOAT, [4, 4])

		# Dynamic `dim` lets adapters of any rank be bound to these inputs.
		lora_param_a_input = G.make_tensor_value_info('lora_param_a', onnx.TensorProto.FLOAT, [4, 'dim'])
		lora_param_b_input = G.make_tensor_value_info('lora_param_b', onnx.TensorProto.FLOAT, ['dim', 4])

		# Was `onnx.helper.make_tensor_value_info`; use the `G` alias for consistency
		# with every other call in this file.
		output = G.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [4, 4])

		weight_x = make_tensor_from_np('weight_x', np.array(range(1, 17)).reshape(4, 4).astype(np.float32))

		# Zero-sized defaults ([4, 0] @ [0, 4] -> zeros) make LoRA a no-op by default.
		lora_param_a = make_tensor_from_np('lora_param_a', np.zeros([4, 0], dtype=np.float32))
		lora_param_b = make_tensor_from_np('lora_param_b', np.zeros([0, 4], dtype=np.float32))

		matmul_x = G.make_node('MatMul', ['input', 'weight_x'], ['mm_output_x'])
		matmul_a = G.make_node('MatMul', ['input', 'lora_param_a'], ['mm_output_a'])
		matmul_b = G.make_node('MatMul', ['mm_output_a', 'lora_param_b'], ['mm_output_b'])
		add_node = G.make_node('Add', ['mm_output_x', 'mm_output_b'], ['output'])

		return G.make_graph(
			nodes=[matmul_x, matmul_a, matmul_b, add_node],
			inputs=[input_x, lora_param_a_input, lora_param_b_input],
			outputs=[output],
			initializer=[weight_x, lora_param_a, lora_param_b],
			name='lora_test'
		)

	@misc_factory
	def lora_adapter():
		"""Export `tests/data/adapter.orl`, a rank-1 LoRA adapter for `lora_model`."""
		param_a = ort.OrtValue.ortvalue_from_numpy(np.array([[3], [4], [5], [6]], dtype=np.float32))
		param_b = ort.OrtValue.ortvalue_from_numpy(np.array([[7, 8, 9, 10]], dtype=np.float32))

		adapter = ort.AdapterFormat()
		adapter.set_parameters({
			'lora_param_a': param_a,
			'lora_param_b': param_b
		})
		adapter.export_adapter('tests/data/adapter.orl')
||||
|
||||
if __name__ == '__main__':
	# Loop variable renamed from `model_factory`, which shadowed the
	# decorator function of the same name.
	for factory in factories:
		factory()
|
||||
Reference in New Issue
Block a user