Skip to content

Commit

Permalink
Merge pull request #743 from schungx/master
Browse files Browse the repository at this point in the history
Fix comparisons.
  • Loading branch information
schungx committed Jul 12, 2023
2 parents ed5bbf0 + cafd0aa commit 84fd228
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 53 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Bug fixes
---------

* Fixes a panic when using `this` as the first parameter in a namespace-qualified function call.
* Comparing two different data types (e.g. a custom type and a standard type) now correctly defaults to `false` (except for `!=` which defaults to `true`).
* `max` and `min` for integers, strings and characters were missing from the standard library. They are now added.

New features
Expand Down
53 changes: 45 additions & 8 deletions src/func/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,46 @@ fn const_true_fn(_: Option<NativeCallContext>, _: &mut [&mut Dynamic]) -> RhaiRe
fn const_false_fn(_: Option<NativeCallContext>, _: &mut [&mut Dynamic]) -> RhaiResult {
Ok(Dynamic::FALSE)
}
/// Returns true if the type is numeric.
#[inline(always)]
fn is_numeric(typ: TypeId) -> bool {
if typ == TypeId::of::<INT>() {
return true;
}

#[cfg(not(feature = "no_float"))]
if typ == TypeId::of::<f32>() || typ == TypeId::of::<f64>() {
return true;
}

#[cfg(feature = "decimal")]
if typ == TypeId::of::<Decimal>() {
return true;
}

#[cfg(not(feature = "only_i32"))]
#[cfg(not(feature = "only_i64"))]
if typ == TypeId::of::<u8>()
|| typ == TypeId::of::<u16>()
|| typ == TypeId::of::<u32>()
|| typ == TypeId::of::<u64>()
|| typ == TypeId::of::<i8>()
|| typ == TypeId::of::<i16>()
|| typ == TypeId::of::<i32>()
|| typ == TypeId::of::<i64>()
{
return true;
}

#[cfg(not(feature = "only_i32"))]
#[cfg(not(feature = "only_i64"))]
#[cfg(not(target_family = "wasm"))]
if typ == TypeId::of::<u128>() || typ == TypeId::of::<i128>() {
return true;
}

false
}

/// Build in common binary operator implementations to avoid the cost of calling a registered function.
///
Expand Down Expand Up @@ -540,16 +580,13 @@ pub fn get_builtin_binary_op_fn(op: &Token, x: &Dynamic, y: &Dynamic) -> Option<
};
}

// One of the operands is a custom type, so it is never built-in
if x.is_variant() || y.is_variant() {
return None;
}

// Default comparison operators for different types
// Default comparison operators for different, non-numeric types
if type2 != type1 {
return match op {
NotEqualsTo => Some((const_true_fn, false)),
EqualsTo | GreaterThan | GreaterThanEqualsTo | LessThan | LessThanEqualsTo => {
NotEqualsTo if !is_numeric(type1) || !is_numeric(type2) => Some((const_true_fn, false)),
EqualsTo | GreaterThan | GreaterThanEqualsTo | LessThan | LessThanEqualsTo
if !is_numeric(type1) || !is_numeric(type2) =>
{
Some((const_false_fn, false))
}
_ => None,
Expand Down
94 changes: 60 additions & 34 deletions src/func/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::tokenizer::{is_valid_function_name, Token};
use crate::types::dynamic::Union;
use crate::{
calc_fn_hash, calc_fn_hash_full, Dynamic, Engine, FnArgsVec, FnPtr, ImmutableString,
OptimizationLevel, Position, RhaiResult, RhaiResultOf, Scope, Shared, ERR,
OptimizationLevel, Position, RhaiResult, RhaiResultOf, Scope, Shared, SmartString, ERR,
};
#[cfg(feature = "no_std")]
use hashbrown::hash_map::Entry;
Expand Down Expand Up @@ -1655,7 +1655,7 @@ impl Engine {
}

// Short-circuit native binary operator call if under Fast Operators mode
if op_token.is_some() && self.fast_operators() && args.len() == 2 {
if self.fast_operators() && args.len() == 2 && op_token.is_some() {
#[allow(clippy::wildcard_imports)]
use Token::*;

Expand All @@ -1665,7 +1665,7 @@ impl Engine {
.flatten();

let mut rhs = self
.get_arg_value(global, caches, scope, this_ptr, &args[1])?
.get_arg_value(global, caches, scope, this_ptr.as_deref_mut(), &args[1])?
.0
.flatten();

Expand All @@ -1679,8 +1679,8 @@ impl Engine {
_ => (),
},
(Union::Bool(b1, ..), Union::Bool(b2, ..)) => match op_token.unwrap() {
EqualsTo => return Ok((*b1 == *b2).into()),
NotEqualsTo => return Ok((*b1 != *b2).into()),
EqualsTo => return Ok((b1 == b2).into()),
NotEqualsTo => return Ok((b1 != b2).into()),
GreaterThan | GreaterThanEqualsTo | LessThan | LessThanEqualsTo => {
return Ok(Dynamic::FALSE)
}
Expand All @@ -1695,12 +1695,12 @@ impl Engine {

#[cfg(not(feature = "unchecked"))]
match op_token.unwrap() {
EqualsTo => return Ok((*n1 == *n2).into()),
NotEqualsTo => return Ok((*n1 != *n2).into()),
GreaterThan => return Ok((*n1 > *n2).into()),
GreaterThanEqualsTo => return Ok((*n1 >= *n2).into()),
LessThan => return Ok((*n1 < *n2).into()),
LessThanEqualsTo => return Ok((*n1 <= *n2).into()),
EqualsTo => return Ok((n1 == n2).into()),
NotEqualsTo => return Ok((n1 != n2).into()),
GreaterThan => return Ok((n1 > n2).into()),
GreaterThanEqualsTo => return Ok((n1 >= n2).into()),
LessThan => return Ok((n1 < n2).into()),
LessThanEqualsTo => return Ok((n1 <= n2).into()),
Plus => return add(*n1, *n2).map(Into::into),
Minus => return subtract(*n1, *n2).map(Into::into),
Multiply => return multiply(*n1, *n2).map(Into::into),
Expand All @@ -1710,17 +1710,17 @@ impl Engine {
}
#[cfg(feature = "unchecked")]
match op_token.unwrap() {
EqualsTo => return Ok((*n1 == *n2).into()),
NotEqualsTo => return Ok((*n1 != *n2).into()),
GreaterThan => return Ok((*n1 > *n2).into()),
GreaterThanEqualsTo => return Ok((*n1 >= *n2).into()),
LessThan => return Ok((*n1 < *n2).into()),
LessThanEqualsTo => return Ok((*n1 <= *n2).into()),
Plus => return Ok((*n1 + *n2).into()),
Minus => return Ok((*n1 - *n2).into()),
Multiply => return Ok((*n1 * *n2).into()),
Divide => return Ok((*n1 / *n2).into()),
Modulo => return Ok((*n1 % *n2).into()),
EqualsTo => return Ok((n1 == n2).into()),
NotEqualsTo => return Ok((n1 != n2).into()),
GreaterThan => return Ok((n1 > n2).into()),
GreaterThanEqualsTo => return Ok((n1 >= n2).into()),
LessThan => return Ok((n1 < n2).into()),
LessThanEqualsTo => return Ok((n1 <= n2).into()),
Plus => return Ok((n1 + n2).into()),
Minus => return Ok((n1 - n2).into()),
Multiply => return Ok((n1 * n2).into()),
Divide => return Ok((n1 / n2).into()),
Modulo => return Ok((n1 % n2).into()),
_ => (),
}
}
Expand Down Expand Up @@ -1776,24 +1776,50 @@ impl Engine {
GreaterThanEqualsTo => return Ok((s1 >= s2).into()),
LessThan => return Ok((s1 < s2).into()),
LessThanEqualsTo => return Ok((s1 <= s2).into()),
Plus => {
#[cfg(not(feature = "unchecked"))]
self.throw_on_size((0, 0, s1.len() + s2.len()))?;
return Ok((s1 + s2).into());
}
Minus => return Ok((s1 - s2).into()),
_ => (),
},
(Union::Char(c1, ..), Union::Char(c2, ..)) => match op_token.unwrap() {
EqualsTo => return Ok((c1 == c2).into()),
NotEqualsTo => return Ok((c1 != c2).into()),
GreaterThan => return Ok((c1 > c2).into()),
GreaterThanEqualsTo => return Ok((c1 >= c2).into()),
LessThan => return Ok((c1 < c2).into()),
LessThanEqualsTo => return Ok((c1 <= c2).into()),
Plus => {
let mut result = SmartString::new_const();
result.push(*c1);
result.push(*c2);

#[cfg(not(feature = "unchecked"))]
self.throw_on_size((0, 0, result.len()))?;

return Ok(result.into());
}
_ => (),
},
_ => (),
(Union::Variant(..), _) | (_, Union::Variant(..)) => (),
_ => {
if let Some((func, need_context)) =
get_builtin_binary_op_fn(op_token.as_ref().unwrap(), &mut lhs, &mut rhs)
{
// We may not need to bump the level because built-in's do not need it.
//defer! { let orig_level = global.level; global.level += 1 }

let context =
need_context.then(|| (self, name.as_str(), None, &*global, pos).into());
return func(context, &mut [&mut lhs, &mut rhs]);
}
}
}

let operands = &mut [&mut lhs, &mut rhs];

if let Some((func, need_context)) =
get_builtin_binary_op_fn(op_token.as_ref().unwrap(), operands[0], operands[1])
{
// We may not need to bump the level because built-in's do not need it.
//defer! { let orig_level = global.level; global.level += 1 }

let context =
need_context.then(|| (self, name.as_str(), None, &*global, pos).into());
return func(context, operands);
}

return self
.exec_fn_call(
global, caches, None, name, op_token, *hashes, operands, false, false, pos,
Expand Down
9 changes: 4 additions & 5 deletions tests/mismatched_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ fn test_mismatched_op_custom_type() -> Result<(), Box<EvalAltResult>> {
.register_type_with_name::<TestStruct>("TestStruct")
.register_fn("new_ts", TestStruct::new);

assert!(!engine.eval::<bool>("new_ts() == 42")?);

assert!(engine.eval::<bool>("new_ts() != ()")?);

assert!(matches!(*engine.eval::<bool>(
"
let x = new_ts();
Expand All @@ -54,11 +58,6 @@ fn test_mismatched_op_custom_type() -> Result<(), Box<EvalAltResult>> {
").unwrap_err(),
EvalAltResult::ErrorFunctionNotFound(f, ..) if f == "== (TestStruct, TestStruct)"));

assert!(
matches!(*engine.eval::<bool>("new_ts() == 42").unwrap_err(),
EvalAltResult::ErrorFunctionNotFound(f, ..) if f.starts_with("== (TestStruct, "))
);

assert!(matches!(
*engine.eval::<INT>("60 + new_ts()").unwrap_err(),
EvalAltResult::ErrorFunctionNotFound(f, ..) if f == format!("+ ({}, TestStruct)", std::any::type_name::<INT>())
Expand Down
14 changes: 8 additions & 6 deletions tests/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,7 @@ fn test_ops_other_number_types() -> Result<(), Box<EvalAltResult>> {
EvalAltResult::ErrorFunctionNotFound(f, ..) if f.starts_with("== (u16,")
));

assert!(
matches!(*engine.eval_with_scope::<bool>(&mut scope, r#"x == "hello""#).unwrap_err(),
EvalAltResult::ErrorFunctionNotFound(f, ..) if f.starts_with("== (u16,")
)
);
assert!(!engine.eval_with_scope::<bool>(&mut scope, r#"x == "hello""#)?);

Ok(())
}
Expand Down Expand Up @@ -82,9 +78,15 @@ fn test_ops_custom_types() -> Result<(), Box<EvalAltResult>> {
.register_type_with_name::<Test2>("Test2")
.register_fn("new_ts1", || Test1)
.register_fn("new_ts2", || Test2)
.register_fn("==", |x: Test1, y: Test2| true);
.register_fn("==", |_: Test1, _: Test2| true);

assert!(engine.eval::<bool>("let x = new_ts1(); let y = new_ts2(); x == y")?);

assert!(engine.eval::<bool>("let x = new_ts1(); let y = new_ts2(); x != y")?);

assert!(!engine.eval::<bool>("let x = new_ts1(); x == ()")?);

assert!(engine.eval::<bool>("let x = new_ts1(); x != ()")?);

Ok(())
}

0 comments on commit 84fd228

Please sign in to comment.