mirror of
https://github.com/pykeio/ort
synced 2026-04-25 16:34:55 +02:00
feat: tests for adapter API
This commit is contained in:
@@ -50,3 +50,58 @@ impl Adapter {
|
|||||||
self.inner.ptr.as_ptr()
|
self.inner.ptr.as_ptr()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::fs;
|
||||||
|
|
||||||
|
use super::Adapter;
|
||||||
|
use crate::{RunOptions, Session, Tensor};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_lora() -> crate::Result<()> {
|
||||||
|
let model = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?;
|
||||||
|
let lora = Adapter::from_file("tests/data/adapter.orl", None)?;
|
||||||
|
|
||||||
|
let mut run_options = RunOptions::new()?;
|
||||||
|
run_options.add_adapter(&lora)?;
|
||||||
|
|
||||||
|
let output: Tensor<f32> = model
|
||||||
|
.run_with_options(crate::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?]?, &run_options)?
|
||||||
|
.remove("output")
|
||||||
|
.expect("")
|
||||||
|
.downcast()?;
|
||||||
|
let (_, output) = output.extract_raw_tensor();
|
||||||
|
assert_eq!(output[0], 154.0);
|
||||||
|
assert_eq!(output[1], 176.0);
|
||||||
|
assert_eq!(output[2], 198.0);
|
||||||
|
assert_eq!(output[3], 220.0);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_lora_from_memory() -> crate::Result<()> {
|
||||||
|
let model = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?;
|
||||||
|
|
||||||
|
let lora_bytes = fs::read("tests/data/adapter.orl").expect("");
|
||||||
|
let lora = Adapter::from_memory(&lora_bytes, None)?;
|
||||||
|
drop(lora_bytes);
|
||||||
|
|
||||||
|
let mut run_options = RunOptions::new()?;
|
||||||
|
run_options.add_adapter(&lora)?;
|
||||||
|
|
||||||
|
let output: Tensor<f32> = model
|
||||||
|
.run_with_options(crate::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?]?, &run_options)?
|
||||||
|
.remove("output")
|
||||||
|
.expect("")
|
||||||
|
.downcast()?;
|
||||||
|
let (_, output) = output.extract_raw_tensor();
|
||||||
|
assert_eq!(output[0], 154.0);
|
||||||
|
assert_eq!(output[1], 176.0);
|
||||||
|
assert_eq!(output[2], 198.0);
|
||||||
|
assert_eq!(output[3], 220.0);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
BIN
tests/data/adapter.orl
Normal file
BIN
tests/data/adapter.orl
Normal file
Binary file not shown.
BIN
tests/data/lora_model.onnx
Normal file
BIN
tests/data/lora_model.onnx
Normal file
Binary file not shown.
84
tools/test-data/generate-models.py
Normal file
84
tools/test-data/generate-models.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import inspect
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import onnx
|
||||||
|
import onnx.helper as G
|
||||||
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
def make_tensor_from_np(name: str, arr: np.ndarray) -> onnx.TensorProto:
|
||||||
|
match arr.dtype:
|
||||||
|
case np.float32:
|
||||||
|
dtype = onnx.TensorProto.FLOAT
|
||||||
|
shape = list(arr.shape)
|
||||||
|
return G.make_tensor(name, dtype, shape, arr)
|
||||||
|
|
||||||
|
factories = []
|
||||||
|
def model_factory(func):
|
||||||
|
model_name = f'{func.__name__}.onnx'
|
||||||
|
def wrapper():
|
||||||
|
try:
|
||||||
|
model = func()
|
||||||
|
except Exception as e:
|
||||||
|
print(f'Failed to create `{model_name}`: {e}')
|
||||||
|
return
|
||||||
|
if isinstance(model, onnx.GraphProto):
|
||||||
|
model = G.make_model(model, opset_imports=[onnx.OperatorSetIdProto(domain=None, version=21)])
|
||||||
|
try:
|
||||||
|
onnx.checker.check_model(model)
|
||||||
|
except Exception as e:
|
||||||
|
print(f'`{model_name}` is invalid: {e}')
|
||||||
|
onnx.save_model(model, f'tests/data/{model_name}')
|
||||||
|
factories.append(wrapper)
|
||||||
|
return wrapper
|
||||||
|
def misc_factory(func):
|
||||||
|
def wrapper():
|
||||||
|
try:
|
||||||
|
func()
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
factories.append(wrapper)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
class Models:
|
||||||
|
@model_factory
|
||||||
|
def lora_model():
|
||||||
|
input_x = G.make_tensor_value_info('input', onnx.TensorProto.FLOAT, [4, 4])
|
||||||
|
|
||||||
|
lora_param_a_input = G.make_tensor_value_info('lora_param_a', onnx.TensorProto.FLOAT, [4, 'dim'])
|
||||||
|
lora_param_b_input = G.make_tensor_value_info('lora_param_b', onnx.TensorProto.FLOAT, ['dim', 4])
|
||||||
|
|
||||||
|
output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [4, 4])
|
||||||
|
|
||||||
|
weight_x = make_tensor_from_np('weight_x', np.array(range(1, 17)).reshape(4, 4).astype(np.float32))
|
||||||
|
|
||||||
|
lora_param_a = make_tensor_from_np('lora_param_a', np.zeros([4, 0], dtype=np.float32))
|
||||||
|
lora_param_b = make_tensor_from_np('lora_param_b', np.zeros([0, 4], dtype=np.float32))
|
||||||
|
|
||||||
|
matmul_x = G.make_node('MatMul', ['input', 'weight_x'], ['mm_output_x'])
|
||||||
|
matmul_a = G.make_node('MatMul', ['input', 'lora_param_a'], ['mm_output_a'])
|
||||||
|
matmul_b = G.make_node('MatMul', ['mm_output_a', 'lora_param_b'], ['mm_output_b'])
|
||||||
|
add_node = G.make_node('Add', ['mm_output_x', 'mm_output_b'], ['output'])
|
||||||
|
|
||||||
|
return G.make_graph(
|
||||||
|
nodes=[matmul_x, matmul_a, matmul_b, add_node],
|
||||||
|
inputs=[input_x, lora_param_a_input, lora_param_b_input],
|
||||||
|
outputs=[output],
|
||||||
|
initializer=[weight_x, lora_param_a, lora_param_b],
|
||||||
|
name='lora_test'
|
||||||
|
)
|
||||||
|
|
||||||
|
@misc_factory
|
||||||
|
def lora_adapter():
|
||||||
|
param_a = ort.OrtValue.ortvalue_from_numpy(np.array([[3], [4], [5], [6]], dtype=np.float32))
|
||||||
|
param_b = ort.OrtValue.ortvalue_from_numpy(np.array([[7, 8, 9, 10]], dtype=np.float32))
|
||||||
|
|
||||||
|
adapter = ort.AdapterFormat()
|
||||||
|
adapter.set_parameters({
|
||||||
|
'lora_param_a': param_a,
|
||||||
|
'lora_param_b': param_b
|
||||||
|
})
|
||||||
|
adapter.export_adapter('tests/data/adapter.orl')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
for model_factory in factories:
|
||||||
|
model_factory()
|
||||||
Reference in New Issue
Block a user