feat: extract num_complex values from complex tensors

This commit is contained in:
Carson M.
2025-03-10 23:40:57 -05:00
parent 7c603aa47e
commit b550675ced
2 changed files with 8 additions and 0 deletions

View File

@@ -64,6 +64,7 @@ training = [ "ort-sys/training" ]
ndarray = [ "dep:ndarray" ]
half = [ "dep:half" ]
num-complex = [ "dep:num-complex" ]
tracing = [ "dep:tracing" ]
fetch-models = [ "std", "dep:ureq", "dep:sha2" ]
@@ -102,6 +103,7 @@ ureq = { version = "3", optional = true, default-features = false, features = [
sha2 = { version = "0.10", optional = true }
tracing = { version = "0.1", optional = true, default-features = false }
half = { version = "2.1", default-features = false, optional = true }
num-complex = { version = "0.4", default-features = false, optional = true }
[dev-dependencies]
anyhow = "1.0"

View File

@@ -208,6 +208,12 @@ impl_type_trait!(u64, Uint64);
#[cfg(feature = "half")]
#[cfg_attr(docsrs, doc(cfg(feature = "half")))]
impl_type_trait!(half::bf16, Bfloat16);
#[cfg(feature = "num-complex")]
#[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
impl_type_trait!(num_complex::Complex32, Complex64);
#[cfg(feature = "num-complex")]
#[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
impl_type_trait!(num_complex::Complex64, Complex128);
impl IntoTensorElementType for String {
fn into_tensor_element_type() -> TensorElementType {