aboutsummaryrefslogtreecommitdiff
path: root/src/expression_function.rs
diff options
context:
space:
mode:
authorNathan Reiner <nathan@nathanreiner.xyz>2024-01-17 23:28:48 +0100
committerNathan Reiner <nathan@nathanreiner.xyz>2024-01-17 23:28:48 +0100
commit9cc61497ed8a2336f33407d3262181e4ac3b46cb (patch)
treeaac75b57b0fffea64abcd23cbac4d27c875fee48 /src/expression_function.rs
parent77cf9aa7535a1d9481f0bd3caeea26e2b85c5019 (diff)
add commonsense and expression_functions
Diffstat (limited to 'src/expression_function.rs')
-rw-r--r--src/expression_function.rs60
1 files changed, 60 insertions, 0 deletions
diff --git a/src/expression_function.rs b/src/expression_function.rs
new file mode 100644
index 0000000..35eba5f
--- /dev/null
+++ b/src/expression_function.rs
@@ -0,0 +1,60 @@
+use std::{collections::HashMap, iter::zip};
+
+use crate::{
+ commonsense_functions, commonsense_operations, functions,
+ complex::Complex,
+ context::Context,
+ expression::Expression,
+ operation::Operation,
+ function::{Function, FunctionArgument}
+};
+
+pub struct ExpressionFunction {
+ expr: Expression,
+ name: String,
+ args: Vec<String>,
+}
+
+impl ExpressionFunction {
+ pub fn from_string(str: String) -> Self {
+ let str = str.replace(' ', "");
+ let (lhs, expr) = str.split_once('=').unwrap();
+ let (name, a) = lhs.split_once('(').unwrap();
+ let args: Vec<String> = a[0..a.len() - 1]
+ .split(',')
+ .into_iter()
+ .map(|s| s.to_string())
+ .collect();
+ Self {
+ expr: Expression::from_string(expr),
+ name: name.to_string(),
+ args,
+ }
+ }
+
+ pub fn name(&self) -> &str {
+ &self.name
+ }
+}
+
+impl Function for ExpressionFunction {
+ fn eval(&self, args: FunctionArgument) -> Result<Complex, String> {
+ if args.len() == self.args.len() {
+ let mut vars = HashMap::new();
+ for (n, v) in zip(self.args.iter(), args.data().iter()) {
+ vars.insert(n.to_string(), v.clone());
+ }
+ let ctx = Context::new()
+ .with_variables(vars)
+ .with_functions(commonsense_functions! {})
+ .with_operations(commonsense_operations! {});
+ self.expr.evaluate(&ctx)
+ } else {
+ Err(format!(
+ "{} takes {} parameters",
+ self.name,
+ self.args.len()
+ ))
+ }
+ }
+}