firls-rs/firls-rs-macros/src/lib.rs

212 lines
5.9 KiB
Rust

extern crate num;
extern crate proc_macro;
extern crate quote;
use firls_rs::{firls, frequency_shift_coeffs};
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::parse_macro_input;
use syn::spanned::Spanned;
use syn::{punctuated::Punctuated, Expr, Lit, Token};
#[proc_macro]
pub fn firls_real(input: TokenStream) -> TokenStream {
let FirlsRealInput {
filter_len,
sampling_frequency,
bands,
span,
} = parse_macro_input!(input as FirlsRealInput);
let output = match firls(filter_len, sampling_frequency, &bands) {
Ok(coeffs) => {
quote! {
[
#(#coeffs),*
]
}
}
Err(msg) => syn::Error::new(span, msg).to_compile_error(),
};
proc_macro::TokenStream::from(output)
}
/// Arguments accepted by the `firls_real!` macro, after parsing.
struct FirlsRealInput {
    /// Span of the argument list; errors raised after parsing are attached here.
    span: Span,
    /// Requested filter length, forwarded to `firls`.
    filter_len: usize,
    /// Sampling frequency, forwarded to `firls`.
    sampling_frequency: f32,
    /// Band points as `(frequency, gain)` pairs, in the order they were written.
    bands: Vec<(f32, f32)>,
}
impl Parse for FirlsRealInput {
    /// Parses `filter_len, sampling_frequency, bands` (a trailing comma is allowed).
    ///
    /// Returns an error if the argument count is not exactly three or if any
    /// argument fails its individual parser.
    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
        // Unlike `parse_separated_nonempty`, `parse_terminated` accepts a trailing
        // comma and empty input; the arity check below then reports both uniformly.
        let arg_list = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
        if arg_list.len() != 3 {
            return Err(syn::Error::new(
                arg_list.span(),
                "firls_real takes 3 parameters",
            ));
        }
        Ok(FirlsRealInput {
            filter_len: parse_filter_len(&arg_list[0])?,
            sampling_frequency: parse_frequency(&arg_list[1])?,
            bands: parse_band_list(&arg_list[2])?,
            span: arg_list.span(),
        })
    }
}
#[proc_macro]
pub fn firls_complex(input: TokenStream) -> TokenStream {
let FirlsComplexInput {
filter_len,
sampling_frequency,
bands,
frequency_shift,
span,
} = parse_macro_input!(input as FirlsComplexInput);
let output = match firls(filter_len, sampling_frequency, &bands) {
Ok(coeffs) => {
let coeffs = frequency_shift_coeffs(&coeffs, sampling_frequency, frequency_shift);
let constructors: Vec<proc_macro2::TokenStream> = coeffs
.iter()
.map(|num::Complex::<f32> { re, im }| {
quote! {
num::Complex::<f32> {re: #re, im: #im}
}
})
.collect();
quote! {
[
#(#constructors),*
]
}
}
Err(msg) => syn::Error::new(span, msg).to_compile_error(),
};
proc_macro::TokenStream::from(output)
}
/// Arguments accepted by the `firls_complex!` macro, after parsing.
struct FirlsComplexInput {
    /// Span of the argument list; errors raised after parsing are attached here.
    span: Span,
    /// Requested filter length, forwarded to `firls`.
    filter_len: usize,
    /// Sampling frequency, forwarded to both `firls` and `frequency_shift_coeffs`.
    sampling_frequency: f32,
    /// Band points as `(frequency, gain)` pairs, in the order they were written.
    bands: Vec<(f32, f32)>,
    /// Frequency offset handed to `frequency_shift_coeffs`.
    frequency_shift: f32,
}
impl Parse for FirlsComplexInput {
fn parse(input: ParseStream) -> Result<Self, syn::Error> {
let arg_list = Punctuated::<Expr, Token![,]>::parse_separated_nonempty(input)?;
if arg_list.len() != 4 {
return Err(syn::Error::new(
arg_list.span(),
"firls_complex takes 4 parameters",
));
}
Ok(FirlsComplexInput {
filter_len: parse_filter_len(&arg_list[0])?,
sampling_frequency: parse_frequency(&arg_list[1])?,
bands: parse_band_list(&arg_list[2])?,
frequency_shift: parse_frequency(&arg_list[3])?,
span: arg_list.span(),
})
}
}
/// Parses the filter-length argument, which must be an integer literal.
///
/// Invisible (`None`-delimited) groups are stripped first. These appear as
/// `Expr::Group` when the literal is forwarded through a `macro_rules!`
/// `$x:expr` fragment.
///
/// # Errors
/// Returns an error for non-literal expressions, non-integer literals, or
/// integers that do not fit in `usize`.
fn parse_filter_len(len: &Expr) -> Result<usize, syn::Error> {
    let mut len = len;
    while let Expr::Group(group) = len {
        len = &*group.expr;
    }
    match len {
        Expr::Lit(expr) => match &expr.lit {
            Lit::Int(int_lit) => int_lit.base10_parse(),
            _ => Err(syn::Error::new(
                expr.span(),
                "expected integer literal for len",
            )),
        },
        _ => Err(syn::Error::new(
            len.span(),
            "len should be a literal expression",
        )),
    }
}
/// Parses a frequency argument from a numeric literal.
///
/// Float literals (e.g. `2.0`) and integer literals (e.g. `48000`) are both
/// accepted and converted to `f32`. Invisible (`None`-delimited) groups, which
/// appear as `Expr::Group` when a literal is forwarded through a
/// `macro_rules!` `$x:expr` fragment, are stripped first.
///
/// # Errors
/// Returns an error for non-literal expressions (including negated literals)
/// and for literals that are not numeric.
fn parse_frequency(freq: &Expr) -> Result<f32, syn::Error> {
    let mut freq = freq;
    while let Expr::Group(group) = freq {
        freq = &*group.expr;
    }
    match freq {
        Expr::Lit(expr) => match &expr.lit {
            Lit::Float(float_lit) => float_lit.base10_parse(),
            // Whole-number frequencies such as a sampling rate of `48000`.
            Lit::Int(int_lit) => int_lit.base10_parse(),
            _ => Err(syn::Error::new(
                expr.span(),
                "expected float literal for frequency",
            )),
        },
        _ => Err(syn::Error::new(
            freq.span(),
            "frequency should be a literal expression",
        )),
    }
}
/// Parses a band gain from a numeric literal.
///
/// Float literals (e.g. `1.0`) and integer literals (e.g. `1`) are both
/// accepted and converted to `f32`. Invisible (`None`-delimited) groups, which
/// appear as `Expr::Group` when a literal is forwarded through a
/// `macro_rules!` `$x:expr` fragment, are stripped first.
///
/// # Errors
/// Returns an error for non-literal expressions (including negated literals)
/// and for literals that are not numeric.
fn parse_gain(gain: &Expr) -> Result<f32, syn::Error> {
    let mut gain = gain;
    while let Expr::Group(group) = gain {
        gain = &*group.expr;
    }
    match gain {
        Expr::Lit(expr) => match &expr.lit {
            Lit::Float(float_lit) => float_lit.base10_parse(),
            // Whole-number gains such as `0` or `1`.
            Lit::Int(int_lit) => int_lit.base10_parse(),
            _ => Err(syn::Error::new(
                expr.span(),
                "expected float literal for gain",
            )),
        },
        _ => Err(syn::Error::new(
            gain.span(),
            "gain should be a literal expression",
        )),
    }
}
fn parse_band_list(bands: &Expr) -> Result<Vec<(f32, f32)>, syn::Error> {
match bands {
Expr::Array(array_expr) => {
let mut result = Vec::new();
for elem in array_expr.elems.iter() {
let parsed_tuple = parse_band_tuple(elem)?;
result.push(parsed_tuple);
}
Ok(result)
}
_ => Err(syn::Error::new(
bands.span(),
"bands should be an array expression",
)),
}
}
fn parse_band_tuple(tuple: &Expr) -> Result<(f32, f32), syn::Error> {
match tuple {
Expr::Tuple(tuple_expr) => {
if tuple_expr.elems.len() == 2 {
let freq = parse_frequency(&tuple_expr.elems[0])?;
let gain = parse_gain(&tuple_expr.elems[1])?;
Ok((freq, gain))
} else {
Err(syn::Error::new(
tuple.span(),
"band points should be a two element tuple",
))
}
}
_ => Err(syn::Error::new(
tuple.span(),
"band points should be a tuple expression",
)),
}
}