From dc578e3488ffb4c8a26da7d0698960b88fc37fcc Mon Sep 17 00:00:00 2001 From: Mattias Ajander <mattias@ajander.se> Date: Fri, 14 Mar 2025 00:44:44 +0100 Subject: [PATCH] New Node evaluate() and get_value() pipeline with new NodeValue class --- Makefile | 3 + include/ast/ExpressionNode.h | 35 ----- include/ast/LiteralNode.h | 2 +- include/ast/Node.h | 12 +- include/ast/NodeValue.h | 133 +++++++++++++++++++ include/token/TokenType.h | 6 +- include/utils/Exception.h | 4 + source/ast/ExpressionNode.cc | 90 ------------- source/ast/LiteralNode.cc | 7 +- source/ast/NodeValue.cc | 240 +++++++++++++++++++++++++++++++++++ source/lexer/Lexer.cc | 3 +- source/token/Token.cc | 2 +- source/token/TokenType.cc | 6 +- tests/TestLiteralNode.cc | 130 +++++++++---------- tests/TestNodeValue.cc | 213 +++++++++++++++++++++++++++++++ 15 files changed, 678 insertions(+), 208 deletions(-) create mode 100644 include/ast/NodeValue.h create mode 100644 source/ast/NodeValue.cc create mode 100644 tests/TestNodeValue.cc diff --git a/Makefile b/Makefile index c8dfeb5..ace28ba 100644 --- a/Makefile +++ b/Makefile @@ -59,3 +59,6 @@ format: find $(SRC_DIR) $(INC_DIR) $(TEST_DIR) -name "*.cc" -o -name "*.h" | xargs clang-format -i .PHONY: all clean format directories lib tests + +# Compile on all cores +MAKEFLAGS += -j$(shell nproc) diff --git a/include/ast/ExpressionNode.h b/include/ast/ExpressionNode.h index 8a49a6d..dd21df2 100644 --- a/include/ast/ExpressionNode.h +++ b/include/ast/ExpressionNode.h @@ -32,41 +32,6 @@ public: * @return The evaluated value of this expression */ virtual NodeValue get_value() const = 0; - - /** - * @brief Checks if the expression's value is of a specific type. - * @tparam T The type to check for - * @return True if the value is of type T - */ - template <typename T> bool is_a() const; - - /** - * @brief Gets the expression's value as a specific type. - * @tparam T The type to retrieve - * @return The value as type T - * @throws TypeError if the value is not of type T - */ - template <typename T> T get() const; - - /** - * @brief Casts the expression's value to a specific type. - * @tparam T The type to cast to - * @return The value cast to type T - * @throws TypeError if the value cannot be cast to type T - */ - template <typename T> T cast() const; - - /** - * @brief Checks if the expression's value is numeric (int or double). - * @return True if the value is numeric - */ - bool is_numeric() const; - - /** - * @brief Checks if the expression's value is Nothing (None). - * @return True if the value is Nothing - */ - bool is_nothing() const; }; } // namespace funk diff --git a/include/ast/LiteralNode.h b/include/ast/LiteralNode.h index 9c98c67..81237fc 100644 --- a/include/ast/LiteralNode.h +++ b/include/ast/LiteralNode.h @@ -27,7 +27,7 @@ public: * @brief Evaluates the literal node. * @return Pointer to the node representing the evaluation result (self for literals) */ - Node* evaluate() override; + Node* evaluate() const override; /** * @brief Converts the literal to a string representation. diff --git a/include/ast/Node.h b/include/ast/Node.h index 20340b2..0aace72 100644 --- a/include/ast/Node.h +++ b/include/ast/Node.h @@ -4,18 +4,13 @@ */ #pragma once +#include "ast/NodeValue.h" #include "utils/Common.h" #include "utils/Exception.h" namespace funk { -/** - * @brief Variant type that can hold any of the primitive values in Funk. - * Represents the possible runtime values that a node in the AST can evaluate to. - */ -using NodeValue = std::variant<int, double, bool, char, String, None>; - /** * @brief Abstract base class for all AST nodes in the Funk language. * Defines the common interface that all AST nodes must implement, including evaluation and string representation. @@ -38,7 +33,7 @@ public: * @brief Evaluates the node and returns the result. * @return Pointer to the node representing the evaluation result */ - virtual Node* evaluate() = 0; + virtual Node* evaluate() const = 0; /** * @brief Converts the node to a string representation. @@ -53,7 +48,8 @@ public: SourceLocation get_location() const; protected: - SourceLocation location; ///< Source location where this node appears in the code + SourceLocation location; ///< Source location where this node appears in the code + mutable Node* cached_eval{nullptr}; ///< Cached evaluation result }; } // namespace funk diff --git a/include/ast/NodeValue.h b/include/ast/NodeValue.h new file mode 100644 index 0000000..93c2f19 --- /dev/null +++ b/include/ast/NodeValue.h @@ -0,0 +1,133 @@ +#pragma once + +#include "utils/Common.h" +#include "utils/Exception.h" +#include <cmath> + +namespace funk +{ + +/** + * @brief Class that wraps a variant to store any primitive value in Funk. + * Provides type checking, conversion, and operator functionality. + */ +class NodeValue +{ +public: + /** + * @brief Constructs a NodeValue with no value. + */ + NodeValue() : value(None{}) {} + /** + * @brief Constructs a NodeValue with the given value. + * @param v The value to initialize with + */ + NodeValue(int v) : value(v) {} + + /** + * @brief Constructs a NodeValue with the given value. + * @param v The value to initialize with + */ + NodeValue(double v) : value(v) {} + + /** + * @brief Constructs a NodeValue with the given value. + * @param v The value to initialize with + */ + NodeValue(bool v) : value(v) {} + + /** + * @brief Constructs a NodeValue with the given value. + * @param v The value to initialize with + */ + NodeValue(char v) : value(v) {} + + /** + * @brief Constructs a NodeValue with the given value. + * @param v The value to initialize with + */ + NodeValue(const String& v) : value(v) {} + + /** + * @brief Constructs a NodeValue with the given value. + * @param v The value to initialize with + */ + NodeValue(None v) : value(v) {} + + /** + * @brief Constructs a NodeValue with the given variant. + * @param v The variant to initialize with + */ + NodeValue(const std::variant<int, double, bool, char, String, None>& v) : value(v) {} + + /** + * @brief Checks if the expression's value is of a specific type. + * @tparam T The type to check for + * @return True if the value is of type T + */ + template <typename T> bool is_a() const; + + /** + * @brief Gets the expression's value as a specific type. + * @tparam T The type to retrieve + * @return The value as type T + * @throws TypeError if the value is not of type T + */ + template <typename T> T get() const; + + /** + * @brief Casts the expression's value to a specific type. + * @tparam T The type to cast to + * @return The value cast to type T + * @throws TypeError if the value cannot be cast to type T + */ + template <typename T> T cast() const; + + /** + * @brief Checks if the expression's value is numeric (int or double). + * @return True if the value is numeric + */ + bool is_numeric() const; + + /** + * @brief Checks if the expression's value is Nothing (None). + * @return True if the value is Nothing + */ + bool is_nothing() const; + + /** + * @brief Gets the underlying variant. + * @return Reference to the underlying variant + */ + auto& get_variant() const { return value; } + +private: + std::variant<int, double, bool, char, String, None> value; +}; + +template <typename Op> NodeValue numeric_op(const NodeValue& lhs, const NodeValue& rhs, Op op); +template <typename Op> NodeValue comparison(const NodeValue& lhs, const NodeValue& rhs, Op op); + +// Arithmetic operators +NodeValue operator+(const NodeValue& lhs, const NodeValue& rhs); +NodeValue operator-(const NodeValue& lhs, const NodeValue& rhs); +NodeValue operator*(const NodeValue& lhs, const NodeValue& rhs); +NodeValue operator/(const NodeValue& lhs, const NodeValue& rhs); +NodeValue operator%(const NodeValue& lhs, const NodeValue& rhs); + +// Comparison operators +NodeValue operator==(const NodeValue& lhs, const NodeValue& rhs); +NodeValue operator!=(const NodeValue& lhs, const NodeValue& rhs); +NodeValue operator<(const NodeValue& lhs, const NodeValue& rhs); +NodeValue operator<=(const NodeValue& lhs, const NodeValue& rhs); +NodeValue operator>(const NodeValue& lhs, const NodeValue& rhs); +NodeValue operator>=(const NodeValue& lhs, const NodeValue& rhs); + +// Logical operators +NodeValue operator&&(const NodeValue& lhs, const NodeValue& rhs); +NodeValue operator||(const NodeValue& lhs, const NodeValue& rhs); + +// Power operation +NodeValue pow(const NodeValue& lhs, const NodeValue& rhs); + +} // namespace funk diff --git a/include/token/TokenType.h b/include/token/TokenType.h index 1456a90..845bf9f 100644 --- a/include/token/TokenType.h +++ b/include/token/TokenType.h @@ -53,7 +53,8 @@ enum class TokenType MINUS, ///< Subtraction operator (-) MULTIPLY, ///< Multiplication operator (*) DIVIDE, ///< Division operator (/) - PERCENT, ///< Modulo operator (%) + MODULO, ///< Modulo operator (%) + POWER, ///< Power operator (^) // Assignment operators ASSIGN, ///< Assignment operator (=) @@ -62,6 +63,7 @@ enum class TokenType MULTIPLY_ASSIGN, ///< Multiply and assign operator (*=) DIVIDE_ASSIGN, ///< Divide and assign operator (/=) MODULO_ASSIGN, ///< Modulo and assign operator (%=) + POWER_ASSIGN, ///< Power and assign operator (^=) // Comparison operators EQUAL, ///< Equality operator (==) @@ -104,5 +106,5 @@ enum class TokenType * @param token The token type to convert * @return A string representation of the token type */ -String token_type_to_string(TokenType token); +String token_type_to_s(TokenType token); } // namespace funk diff --git a/include/utils/Exception.h b/include/utils/Exception.h index 2e081c6..40d423b 100644 --- a/include/utils/Exception.h +++ b/include/utils/Exception.h @@ -56,6 +56,7 @@ public: * @param message Description of the error */ LexerError(const SourceLocation& loc, const String& message) : FunkError(loc, "Lexer error", message) {} + LexerError(const String& message) : FunkError(SourceLocation{"", 0, 0}, "Lexer error", message) {} }; /** @@ -71,6 +72,7 @@ public: * @param message Description of the error */ SyntaxError(const SourceLocation& loc, const String& message) : FunkError(loc, "Syntax error", message) {} + SyntaxError(const String& message) : FunkError(SourceLocation{"", 0, 0}, "Syntax error", message) {} }; /** @@ -86,6 +88,7 @@ public: * @param message Description of the error */ TypeError(const SourceLocation& loc, const String& message) : FunkError(loc, "Type error", message) {} + TypeError(const String& message) : FunkError(SourceLocation{"", 0, 0}, "Type error", message) {} }; /** @@ -101,6 +104,7 @@ public: * @param message Description of the error */ RuntimeError(const SourceLocation& loc, const String& message) : FunkError(loc, "Runtime error", message) {} + RuntimeError(const String& message) : FunkError(SourceLocation{"", 0, 0}, "Runtime error", message) {} }; } // namespace funk diff --git a/source/ast/ExpressionNode.cc b/source/ast/ExpressionNode.cc index e3fd448..c704942 100644 --- a/source/ast/ExpressionNode.cc +++ b/source/ast/ExpressionNode.cc @@ -5,94 +5,4 @@ namespace funk ExpressionNode::ExpressionNode(const SourceLocation& loc) : Node(loc) {} -template <typename T> bool ExpressionNode::is_a() const -{ - return std::holds_alternative<T>(get_value()); -} - -template <typename T> T ExpressionNode::get() const -{ - if (!is_a<T>()) { throw TypeError(location, "Unexpected type"); } - return std::get<T>(get_value()); -} - -template <typename T> T ExpressionNode::cast() const -{ - if (is_a<T>()) { return get<T>(); } - if constexpr (std::is_same_v<T, None>) { throw TypeError(location, "Cannot cast 'none' to type"); } - - NodeValue value{get_value()}; - - if constexpr (std::is_same_v<T, String>) - { - if (is_a<int>()) - return std::to_string(get<int>()); - else if (is_a<double>()) - return std::to_string(get<double>()); - else if (is_a<bool>()) - return get<bool>() ? "true" : "false"; - else if (is_a<char>()) - return String(1, get<char>()); - else if (is_a<None>()) - return "none"; - } - else if constexpr (std::is_same_v<T, int>) - { - if (is_a<double>()) { return static_cast<int>(get<double>()); } - else if (is_a<bool>()) { return get<bool>() ? 1 : 0; } - else if (is_a<char>()) { return get<char>(); } - } - else if constexpr (std::is_same_v<T, double>) - { - if (is_a<int>()) { return static_cast<double>(get<int>()); } - else if (is_a<bool>()) { return get<bool>() ? 1.0 : 0.0; } - else if (is_a<char>()) { return static_cast<double>(get<char>()); } - } - else if constexpr (std::is_same_v<T, bool>) - { - if (is_a<int>()) { return get<int>() != 0; } - else if (is_a<double>()) { return get<double>() != 0.0; } - else if (is_a<char>()) { return get<char>() != '\0'; } - else if (is_a<String>()) { return !get<String>().empty(); } - else if (is_a<None>()) { return false; } - } - else if constexpr (std::is_same_v<T, char>) - { - if (is_a<int>()) { return static_cast<char>(get<int>()); } - } - - throw TypeError(location, "Cannot cast to type"); -} - -bool ExpressionNode::is_numeric() const -{ - return is_a<int>() || is_a<double>(); -} - -bool ExpressionNode::is_nothing() const -{ - return is_a<None>(); -} - -template bool ExpressionNode::is_a<int>() const; -template bool ExpressionNode::is_a<double>() const; -template bool ExpressionNode::is_a<bool>() const; -template bool ExpressionNode::is_a<char>() const; -template bool ExpressionNode::is_a<String>() const; -template bool ExpressionNode::is_a<None>() const; - -template int ExpressionNode::get<int>() const; -template double ExpressionNode::get<double>() const; -template bool ExpressionNode::get<bool>() const; -template char ExpressionNode::get<char>() const; -template String ExpressionNode::get<String>() const; -template None ExpressionNode::get<None>() const; - -template int ExpressionNode::cast<int>() const; -template double ExpressionNode::cast<double>() const; -template bool ExpressionNode::cast<bool>() const; -template char ExpressionNode::cast<char>() const; -template String ExpressionNode::cast<String>() const; -template None ExpressionNode::cast<None>() const; - } // namespace funk diff --git a/source/ast/LiteralNode.cc b/source/ast/LiteralNode.cc index c8479b2..94d2335 100644 --- a/source/ast/LiteralNode.cc +++ b/source/ast/LiteralNode.cc @@ -5,14 +5,15 @@ namespace funk LiteralNode::LiteralNode(const SourceLocation& loc, NodeValue value) : ExpressionNode(loc), value(value) {} -Node* LiteralNode::evaluate() +Node* LiteralNode::evaluate() const { - return this; + if (cached_eval) { return cached_eval; } + return cached_eval = const_cast<LiteralNode*>(this); } String LiteralNode::to_s() const { - return cast<String>(); + return value.cast<String>(); } NodeValue LiteralNode::get_value() const diff --git a/source/ast/NodeValue.cc b/source/ast/NodeValue.cc new file mode 100644 index 0000000..ecec453 --- /dev/null +++ b/source/ast/NodeValue.cc @@ -0,0 +1,240 @@ +#include "ast/NodeValue.h" + +namespace funk +{ + +template <typename T> bool NodeValue::is_a() const +{ + return std::holds_alternative<T>(value); +} + +template <typename T> T NodeValue::get() const +{ + if (!is_a<T>()) { throw TypeError("Unexpected type"); } + return std::get<T>(value); +} + +template <typename T> T NodeValue::cast() const +{ + if (is_a<T>()) { return get<T>(); } + if constexpr (std::is_same_v<T, None>) { throw TypeError("Cannot cast 'none' to type"); } + + if constexpr (std::is_same_v<T, String>) + { + if (is_a<int>()) + return std::to_string(get<int>()); + else if (is_a<double>()) + return std::to_string(get<double>()); + else if (is_a<bool>()) + return get<bool>() ? "true" : "false"; + else if (is_a<char>()) + return String(1, get<char>()); + else if (is_a<None>()) + return "none"; + } + else if constexpr (std::is_same_v<T, int>) + { + if (is_a<double>()) { return static_cast<int>(get<double>()); } + else if (is_a<bool>()) { return get<bool>() ? 1 : 0; } + else if (is_a<char>()) { return get<char>(); } + } + else if constexpr (std::is_same_v<T, double>) + { + if (is_a<int>()) { return static_cast<double>(get<int>()); } + else if (is_a<bool>()) { return get<bool>() ? 1.0 : 0.0; } + else if (is_a<char>()) { return static_cast<double>(get<char>()); } + } + else if constexpr (std::is_same_v<T, bool>) + { + if (is_a<int>()) { return get<int>() != 0; } + else if (is_a<double>()) { return get<double>() != 0.0; } + else if (is_a<char>()) { return get<char>() != '\0'; } + else if (is_a<String>()) { return !get<String>().empty(); } + else if (is_a<None>()) { return false; } + } + else if constexpr (std::is_same_v<T, char>) + { + if (is_a<int>()) { return static_cast<char>(get<int>()); } + } + + throw TypeError("Cannot cast to type"); +} + +bool NodeValue::is_numeric() const +{ + return is_a<int>() || is_a<double>(); +} + +bool NodeValue::is_nothing() const +{ + return is_a<None>(); +} + +template <typename Op> NodeValue numeric_op(const NodeValue& lhs, const NodeValue& rhs, Op op) +{ + if (!lhs.is_numeric() || !rhs.is_numeric()) + { + throw TypeError("Cannot perform arithmetic operation on " + lhs.cast<String>() + " and " + rhs.cast<String>()); + } + + if (lhs.is_a<int>() && rhs.is_a<int>()) { return op(lhs.get<int>(), rhs.get<int>()); } + else { return op(lhs.cast<double>(), rhs.cast<double>()); } +} + +template <typename Op> NodeValue comparison(const NodeValue& lhs, const NodeValue& rhs, Op op) +{ + if (lhs.is_a<int>() && rhs.is_a<int>()) { return op(lhs.get<int>(), rhs.get<int>()); } + else if (lhs.is_a<double>() && rhs.is_a<double>()) { return op(lhs.get<double>(), rhs.get<double>()); } + else if (lhs.is_a<bool>() && rhs.is_a<bool>()) { return op(lhs.get<bool>(), rhs.get<bool>()); } + else if (lhs.is_a<char>() && rhs.is_a<char>()) { return op(lhs.get<char>(), rhs.get<char>()); } + else if (lhs.is_a<String>() && rhs.is_a<String>()) { return op(lhs.get<String>(), rhs.get<String>()); } + else if (lhs.is_numeric() && rhs.is_numeric()) { return op(lhs.cast<double>(), rhs.cast<double>()); } + + throw TypeError("Cannot compare values of types " + lhs.cast<String>() + " and " + rhs.cast<String>()); +} + +NodeValue operator+(const NodeValue& lhs, const NodeValue& rhs) +{ + if (lhs.is_a<String>() && rhs.is_a<String>()) { return lhs.get<String>() + rhs.get<String>(); } + + return numeric_op(lhs, rhs, [](auto a, auto b) + { + return a + b; + }); +} + +NodeValue operator-(const NodeValue& lhs, const NodeValue& rhs) +{ + return numeric_op(lhs, rhs, [](auto a, auto b) + { + return a - b; + }); +} + +NodeValue operator*(const NodeValue& lhs, const NodeValue& rhs) +{ + return numeric_op(lhs, rhs, [](auto a, auto b) + { + return a * b; + }); +} + +NodeValue operator/(const NodeValue& lhs, const NodeValue& rhs) +{ + if ((rhs.is_a<int>() && rhs.get<int>() == 0) || (rhs.is_a<double>() && rhs.get<double>() == 0.0)) + { + throw RuntimeError("Division by zero"); + } + + return numeric_op(lhs, rhs, [](auto a, auto b) + { + return a / b; + }); +} + +NodeValue operator%(const NodeValue& lhs, const NodeValue& rhs) +{ + if (lhs.is_a<int>() && rhs.is_a<int>()) + { + if (rhs.get<int>() == 0) { throw RuntimeError("Modulo by zero"); } + return lhs.get<int>() % rhs.get<int>(); + } + + throw TypeError("Modulo operation requires integer operands"); +} + +NodeValue operator==(const NodeValue& lhs, const NodeValue& rhs) +{ + if (lhs.is_a<None>() || rhs.is_a<None>()) { return lhs.is_a<None>() && rhs.is_a<None>(); } + return comparison(lhs, rhs, [](auto a, auto b) + { + return a == b; + }); +} + +NodeValue operator!=(const NodeValue& lhs, const NodeValue& rhs) +{ + return comparison(lhs, rhs, [](auto a, auto b) + { + return a != b; + }); +} + +NodeValue operator<(const NodeValue& lhs, const NodeValue& rhs) +{ + return comparison(lhs, rhs, [](auto a, auto b) + { + return a < b; + }); +} + +NodeValue operator<=(const NodeValue& lhs, const NodeValue& rhs) +{ + return comparison(lhs, rhs, [](auto a, auto b) + { + return a <= b; + }); +} + +NodeValue operator>(const NodeValue& lhs, const NodeValue& rhs) +{ + return comparison(lhs, rhs, [](auto a, auto b) + { + return a > b; + }); +} + +NodeValue operator>=(const NodeValue& lhs, const NodeValue& rhs) +{ + return comparison(lhs, rhs, [](auto a, auto b) + { + return a >= b; + }); +} + +NodeValue operator&&(const NodeValue& lhs, const NodeValue& rhs) +{ + return comparison(lhs, rhs, [](auto a, auto b) + { + return a && b; + }); +} + +NodeValue operator||(const NodeValue& lhs, const NodeValue& rhs) +{ + return comparison(lhs, rhs, [](auto a, auto b) + { + return a || b; + }); +} + +NodeValue pow(const NodeValue& lhs, const NodeValue& rhs) +{ + if (!lhs.is_numeric() || !rhs.is_numeric()) { throw TypeError("Cannot raise non-numeric value to a power"); } + + if (lhs.is_a<int>() && rhs.is_a<int>()) { return static_cast<int>(std::pow(lhs.get<int>(), rhs.get<int>())); } + else { return std::pow(lhs.cast<double>(), rhs.cast<double>()); } +} + +template bool NodeValue::is_a<int>() const; +template bool NodeValue::is_a<double>() const; +template bool NodeValue::is_a<bool>() const; +template bool NodeValue::is_a<char>() const; +template bool NodeValue::is_a<String>() const; +template bool NodeValue::is_a<None>() const; + +template int NodeValue::get<int>() const; +template double NodeValue::get<double>() const; +template bool NodeValue::get<bool>() const; +template char NodeValue::get<char>() const; +template String NodeValue::get<String>() const; +template None NodeValue::get<None>() const; + +template int NodeValue::cast<int>() const; +template double NodeValue::cast<double>() const; +template bool NodeValue::cast<bool>() const; +template char NodeValue::cast<char>() const; +template String NodeValue::cast<String>() const; +template None NodeValue::cast<None>() const; + +} // namespace funk diff --git a/source/lexer/Lexer.cc b/source/lexer/Lexer.cc index fc7b4bf..2e0c112 100644 --- a/source/lexer/Lexer.cc +++ b/source/lexer/Lexer.cc @@ -64,7 +64,8 @@ Token Lexer::next_token() case ',': return make_token(lexeme, TokenType::COMMA); case '.': return make_token(lexeme, TokenType::DOT); case ';': return make_token(lexeme, TokenType::SEMICOLON); - case '%': return make_token(lexeme, TokenType::PERCENT); + case '%': return make_token(lexeme, TokenType::MODULO); + case '^': return make_token(lexeme, TokenType::POWER); case '+': if (match('=')) { return make_token(lexeme + '=', TokenType::PLUS_ASSIGN); } diff --git a/source/token/Token.cc b/source/token/Token.cc index cd57574..fe463b8 100644 --- a/source/token/Token.cc +++ b/source/token/Token.cc @@ -51,7 +51,7 @@ String Token::to_s() const { std::ostringstream oss; oss << "Token("; - oss << "type=" << token_type_to_string(type) << ", "; + oss << "type=" << token_type_to_s(type) << ", "; oss << "lexeme=\"" << lexeme << "\", "; // Include value if it's not monostate diff --git a/source/token/TokenType.cc b/source/token/TokenType.cc index 56b867a..54dfbfd 100644 --- a/source/token/TokenType.cc +++ b/source/token/TokenType.cc @@ -3,7 +3,7 @@ namespace funk { -String token_type_to_string(TokenType token) +String token_type_to_s(TokenType token) { switch (token) { @@ -36,7 +36,8 @@ String token_type_to_string(TokenType token) case TokenType::MINUS: return "-"; case TokenType::MULTIPLY: return "*"; case TokenType::DIVIDE: return "/"; - case TokenType::PERCENT: return "%"; + case TokenType::MODULO: return "%"; + case TokenType::POWER: return "^"; case TokenType::ASSIGN: return "="; case TokenType::PLUS_ASSIGN: return "+="; @@ -44,6 +45,7 @@ String token_type_to_string(TokenType token) case TokenType::MULTIPLY_ASSIGN: return "*="; case TokenType::DIVIDE_ASSIGN: return "/="; case TokenType::MODULO_ASSIGN: return "%="; + case TokenType::POWER_ASSIGN: return "^="; case TokenType::EQUAL: return "=="; case TokenType::NOT_EQUAL: return "!="; diff --git a/tests/TestLiteralNode.cc b/tests/TestLiteralNode.cc index 457b7f6..cda9b68 100644 --- a/tests/TestLiteralNode.cc +++ b/tests/TestLiteralNode.cc @@ -23,17 +23,17 @@ TEST_F(TestLiteralNode, Integer) SourceLocation loc{"test.funk", 0, 0}; LiteralNode node{loc, 5}; - ASSERT_TRUE(node.is_numeric()); - ASSERT_TRUE(node.is_a<int>()); - ASSERT_FALSE(node.is_a<double>()); - ASSERT_FALSE(node.is_nothing()); - ASSERT_EQ(node.get<int>(), 5); - ASSERT_THROW(node.get<double>(), TypeError); - ASSERT_THROW(node.cast<None>(), TypeError); - ASSERT_EQ(node.cast<double>(), 5.0); - ASSERT_EQ(node.cast<bool>(), true); - ASSERT_EQ(node.cast<char>(), static_cast<char>(5)); - ASSERT_EQ(node.cast<String>(), "5"); + ASSERT_TRUE(node.get_value().is_numeric()); + ASSERT_TRUE(node.get_value().is_a<int>()); + ASSERT_FALSE(node.get_value().is_a<double>()); + ASSERT_FALSE(node.get_value().is_nothing()); + ASSERT_EQ(node.get_value().get<int>(), 5); + ASSERT_THROW(node.get_value().get<double>(), TypeError); + ASSERT_THROW(node.get_value().cast<None>(), TypeError); + ASSERT_EQ(node.get_value().cast<double>(), 5.0); + ASSERT_EQ(node.get_value().cast<bool>(), true); + ASSERT_EQ(node.get_value().cast<char>(), static_cast<char>(5)); + ASSERT_EQ(node.get_value().cast<String>(), "5"); } TEST_F(TestLiteralNode, Double) @@ -41,35 +41,35 @@ TEST_F(TestLiteralNode, Double) SourceLocation loc{"test.funk", 0, 0}; LiteralNode node{loc, 3.14}; - ASSERT_TRUE(node.is_numeric()); - ASSERT_FALSE(node.is_a<int>()); - ASSERT_TRUE(node.is_a<double>()); - ASSERT_FALSE(node.is_nothing()); - ASSERT_EQ(node.get<double>(), 3.14); - ASSERT_THROW(node.get<int>(), TypeError); - ASSERT_EQ(node.cast<int>(), 3); - ASSERT_EQ(node.cast<bool>(), true); - ASSERT_THROW(node.cast<char>(), TypeError); + ASSERT_TRUE(node.get_value().is_numeric()); + ASSERT_FALSE(node.get_value().is_a<int>()); + ASSERT_TRUE(node.get_value().is_a<double>()); + ASSERT_FALSE(node.get_value().is_nothing()); + ASSERT_EQ(node.get_value().get<double>(), 3.14); + ASSERT_THROW(node.get_value().get<int>(), TypeError); + ASSERT_EQ(node.get_value().cast<int>(), 3); + ASSERT_EQ(node.get_value().cast<bool>(), true); + ASSERT_THROW(node.get_value().cast<char>(), TypeError); } TEST_F(TestLiteralNode, Boolean) { SourceLocation loc{"test.funk", 0, 0}; - LiteralNode trueNode{loc, true}; - ASSERT_FALSE(trueNode.is_numeric()); - ASSERT_TRUE(trueNode.is_a<bool>()); - ASSERT_FALSE(trueNode.is_nothing()); - ASSERT_EQ(trueNode.get<bool>(), true); - ASSERT_EQ(trueNode.cast<int>(), 1); - ASSERT_EQ(trueNode.cast<double>(), 1.0); - ASSERT_EQ(trueNode.cast<String>(), "true"); - - LiteralNode falseNode{loc, false}; - ASSERT_EQ(falseNode.get<bool>(), false); - ASSERT_EQ(falseNode.cast<int>(), 0); - ASSERT_EQ(falseNode.cast<double>(), 0.0); - ASSERT_EQ(falseNode.cast<String>(), "false"); + LiteralNode true_node{loc, true}; + ASSERT_FALSE(true_node.get_value().is_numeric()); + ASSERT_TRUE(true_node.get_value().is_a<bool>()); + ASSERT_FALSE(true_node.get_value().is_nothing()); + ASSERT_EQ(true_node.get_value().get<bool>(), true); + ASSERT_EQ(true_node.get_value().cast<int>(), 1); + ASSERT_EQ(true_node.get_value().cast<double>(), 1.0); + ASSERT_EQ(true_node.get_value().cast<String>(), "true"); + + LiteralNode false_node{loc, false}; + ASSERT_EQ(false_node.get_value().get<bool>(), false); + ASSERT_EQ(false_node.get_value().cast<int>(), 0); + ASSERT_EQ(false_node.get_value().cast<double>(), 0.0); + ASSERT_EQ(false_node.get_value().cast<String>(), "false"); } TEST_F(TestLiteralNode, String) @@ -77,15 +77,15 @@ TEST_F(TestLiteralNode, String) SourceLocation loc{"test.funk", 0, 0}; LiteralNode node{loc, String("hello")}; - ASSERT_FALSE(node.is_numeric()); - ASSERT_TRUE(node.is_a<String>()); - ASSERT_FALSE(node.is_nothing()); - ASSERT_EQ(node.get<String>(), "hello"); + ASSERT_FALSE(node.get_value().is_numeric()); + ASSERT_TRUE(node.get_value().is_a<String>()); + ASSERT_FALSE(node.get_value().is_nothing()); + ASSERT_EQ(node.get_value().get<String>(), "hello"); ASSERT_EQ(node.to_s(), "hello"); - ASSERT_EQ(node.cast<bool>(), true); - ASSERT_THROW(node.cast<int>(), TypeError); - ASSERT_THROW(node.cast<double>(), TypeError); - ASSERT_THROW(node.cast<char>(), TypeError); + ASSERT_EQ(node.get_value().cast<bool>(), true); + ASSERT_THROW(node.get_value().cast<int>(), TypeError); + ASSERT_THROW(node.get_value().cast<double>(), TypeError); + ASSERT_THROW(node.get_value().cast<char>(), TypeError); } TEST_F(TestLiteralNode, Char) @@ -93,14 +93,14 @@ TEST_F(TestLiteralNode, Char) SourceLocation loc{"test.funk", 0, 0}; LiteralNode node{loc, 'A'}; - ASSERT_FALSE(node.is_numeric()); - ASSERT_TRUE(node.is_a<char>()); - ASSERT_FALSE(node.is_nothing()); - ASSERT_EQ(node.get<char>(), 'A'); - ASSERT_EQ(node.cast<int>(), 65); - ASSERT_EQ(node.cast<double>(), 65.0); - ASSERT_EQ(node.cast<bool>(), true); - ASSERT_EQ(node.cast<String>(), "A"); + ASSERT_FALSE(node.get_value().is_numeric()); + ASSERT_TRUE(node.get_value().is_a<char>()); + ASSERT_FALSE(node.get_value().is_nothing()); + ASSERT_EQ(node.get_value().get<char>(), 'A'); + ASSERT_EQ(node.get_value().cast<int>(), 65); + ASSERT_EQ(node.get_value().cast<double>(), 65.0); + ASSERT_EQ(node.get_value().cast<bool>(), true); + ASSERT_EQ(node.get_value().cast<String>(), "A"); } TEST_F(TestLiteralNode, None) @@ -108,30 +108,30 @@ TEST_F(TestLiteralNode, None) SourceLocation loc{"test.funk", 0, 0}; LiteralNode node{loc, None{}}; - ASSERT_FALSE(node.is_numeric()); - ASSERT_TRUE(node.is_a<None>()); - ASSERT_TRUE(node.is_nothing()); - ASSERT_NO_THROW(node.get<None>()); - ASSERT_EQ(node.cast<bool>(), false); - ASSERT_EQ(node.cast<String>(), "none"); - ASSERT_THROW(node.cast<int>(), TypeError); - ASSERT_THROW(node.cast<double>(), TypeError); - ASSERT_THROW(node.cast<char>(), TypeError); + ASSERT_FALSE(node.get_value().is_numeric()); + ASSERT_TRUE(node.get_value().is_a<None>()); + ASSERT_TRUE(node.get_value().is_nothing()); + ASSERT_NO_THROW(node.get_value().get<None>()); + ASSERT_EQ(node.get_value().cast<bool>(), false); + ASSERT_EQ(node.get_value().cast<String>(), "none"); + ASSERT_THROW(node.get_value().cast<int>(), TypeError); + ASSERT_THROW(node.get_value().cast<double>(), TypeError); + ASSERT_THROW(node.get_value().cast<char>(), TypeError); } TEST_F(TestLiteralNode, GetValue) { SourceLocation loc{"test.funk", 0, 0}; - LiteralNode intNode{loc, 42}; - NodeValue value = intNode.get_value(); - ASSERT_TRUE(std::holds_alternative<int>(value)); - ASSERT_EQ(std::get<int>(value), 42); + LiteralNode int_node{loc, 42}; + NodeValue value = int_node.get_value(); + ASSERT_TRUE(std::holds_alternative<int>(value.get_variant())); + ASSERT_EQ(std::get<int>(value.get_variant()), 42); LiteralNode stringNode{loc, String("test")}; value = stringNode.get_value(); - ASSERT_TRUE(std::holds_alternative<String>(value)); - ASSERT_EQ(std::get<String>(value), "test"); + ASSERT_TRUE(std::holds_alternative<String>(value.get_variant())); + ASSERT_EQ(std::get<String>(value.get_variant()), "test"); } TEST_F(TestLiteralNode, Evaluate) diff --git a/tests/TestNodeValue.cc b/tests/TestNodeValue.cc new file mode 100644 index 0000000..70e01ec --- /dev/null +++ b/tests/TestNodeValue.cc @@ -0,0 +1,213 @@ +#include "ast/NodeValue.h" +#include "utils/Common.h" +#include <gtest/gtest.h> + +using namespace funk; + +class TestNodeValue : public ::testing::Test +{ +protected: + void SetUp() override + { + // Setup code if needed + } + + void TearDown() override + { + // Cleanup code if needed + } +}; + +TEST_F(TestNodeValue, Integer) +{ + NodeValue value{5}; + ASSERT_TRUE(value.is_numeric()); + ASSERT_TRUE(value.is_a<int>()); + ASSERT_FALSE(value.is_a<double>()); + ASSERT_FALSE(value.is_nothing()); + ASSERT_EQ(value.get<int>(), 5); + ASSERT_THROW(value.get<double>(), TypeError); + ASSERT_THROW(value.cast<None>(), TypeError); + ASSERT_EQ(value.cast<double>(), 5.0); + ASSERT_EQ(value.cast<bool>(), true); + ASSERT_EQ(value.cast<char>(), static_cast<char>(5)); + ASSERT_EQ(value.cast<String>(), "5"); +} + +TEST_F(TestNodeValue, Double) +{ + NodeValue value{3.14}; + ASSERT_TRUE(value.is_numeric()); + ASSERT_FALSE(value.is_a<int>()); + ASSERT_TRUE(value.is_a<double>()); + ASSERT_FALSE(value.is_nothing()); + ASSERT_EQ(value.get<double>(), 3.14); + ASSERT_THROW(value.get<int>(), TypeError); + ASSERT_EQ(value.cast<int>(), 3); + ASSERT_EQ(value.cast<bool>(), true); + ASSERT_THROW(value.cast<char>(), TypeError); +} + +TEST_F(TestNodeValue, Boolean) +{ + NodeValue true_value{true}; + ASSERT_FALSE(true_value.is_numeric()); + ASSERT_TRUE(true_value.is_a<bool>()); + ASSERT_FALSE(true_value.is_nothing()); + ASSERT_EQ(true_value.get<bool>(), true); + ASSERT_EQ(true_value.cast<int>(), 1); + ASSERT_EQ(true_value.cast<double>(), 1.0); + ASSERT_EQ(true_value.cast<String>(), "true"); + + NodeValue false_value{false}; + ASSERT_EQ(false_value.get<bool>(), false); + ASSERT_EQ(false_value.cast<int>(), 0); + ASSERT_EQ(false_value.cast<double>(), 0.0); + ASSERT_EQ(false_value.cast<String>(), "false"); +} + +TEST_F(TestNodeValue, String) +{ + NodeValue value{String("hello")}; + ASSERT_FALSE(value.is_numeric()); + ASSERT_TRUE(value.is_a<String>()); + ASSERT_FALSE(value.is_nothing()); + ASSERT_EQ(value.get<String>(), "hello"); + ASSERT_EQ(value.cast<bool>(), true); + ASSERT_THROW(value.cast<int>(), TypeError); + ASSERT_THROW(value.cast<double>(), TypeError); + ASSERT_THROW(value.cast<char>(), TypeError); +} + +TEST_F(TestNodeValue, Char) +{ + NodeValue value{'A'}; + ASSERT_FALSE(value.is_numeric()); + ASSERT_TRUE(value.is_a<char>()); + ASSERT_FALSE(value.is_nothing()); + ASSERT_EQ(value.get<char>(), 'A'); + ASSERT_EQ(value.cast<int>(), 65); + ASSERT_EQ(value.cast<double>(), 65.0); + ASSERT_EQ(value.cast<bool>(), true); + ASSERT_EQ(value.cast<String>(), "A"); +} + +TEST_F(TestNodeValue, None) +{ + NodeValue value{None{}}; + ASSERT_FALSE(value.is_numeric()); + ASSERT_TRUE(value.is_a<None>()); + ASSERT_TRUE(value.is_nothing()); + ASSERT_NO_THROW(value.get<None>()); + ASSERT_EQ(value.cast<bool>(), false); + ASSERT_EQ(value.cast<String>(), "none"); + ASSERT_THROW(value.cast<int>(), TypeError); + ASSERT_THROW(value.cast<double>(), TypeError); + ASSERT_THROW(value.cast<char>(), TypeError); +} + +TEST_F(TestNodeValue, GetVariant) +{ + NodeValue int_value{42}; + auto variant = int_value.get_variant(); + ASSERT_TRUE(std::holds_alternative<int>(variant)); + ASSERT_EQ(std::get<int>(variant), 42); + + NodeValue string_value{String("test")}; + variant = string_value.get_variant(); + ASSERT_TRUE(std::holds_alternative<String>(variant)); + ASSERT_EQ(std::get<String>(variant), "test"); +} + +TEST_F(TestNodeValue, ArithmeticOperations) +{ + NodeValue a{5}; + NodeValue b{3}; + + NodeValue sum = a + b; + ASSERT_TRUE(sum.is_a<int>()); + ASSERT_EQ(sum.get<int>(), 8); + + NodeValue diff = a - b; + ASSERT_TRUE(diff.is_a<int>()); + ASSERT_EQ(diff.get<int>(), 2); + + NodeValue product = a * b; + ASSERT_TRUE(product.is_a<int>()); + ASSERT_EQ(product.get<int>(), 15); + + NodeValue quotient = a / b; + ASSERT_TRUE(quotient.is_a<int>()); + ASSERT_EQ(quotient.get<int>(), 1); + + NodeValue modulo = a % b; + ASSERT_TRUE(modulo.is_a<int>()); + ASSERT_EQ(modulo.get<int>(), 2); + + NodeValue power = pow(a, b); + ASSERT_TRUE(power.is_a<int>()); + ASSERT_EQ(power.get<int>(), 125); +} + +TEST_F(TestNodeValue, ComparisonOperations) +{ + NodeValue a{5}; + NodeValue b{3}; + NodeValue c{5}; + + NodeValue eq1 = (a == b); + ASSERT_TRUE(eq1.is_a<bool>()); + ASSERT_EQ(eq1.get<bool>(), false); + + NodeValue eq2 = (a == c); + ASSERT_TRUE(eq2.is_a<bool>()); + ASSERT_EQ(eq2.get<bool>(), true); + + NodeValue neq = (a != b); + ASSERT_TRUE(neq.is_a<bool>()); + ASSERT_EQ(neq.get<bool>(), true); + + NodeValue gt = (a > b); + ASSERT_TRUE(gt.is_a<bool>()); + ASSERT_EQ(gt.get<bool>(), true); + + NodeValue lt = (a < b); + ASSERT_TRUE(lt.is_a<bool>()); + ASSERT_EQ(lt.get<bool>(), false); + + NodeValue ge = (a >= c); + ASSERT_TRUE(ge.is_a<bool>()); + ASSERT_EQ(ge.get<bool>(), true); + + NodeValue le = (a <= c); + ASSERT_TRUE(le.is_a<bool>()); + ASSERT_EQ(le.get<bool>(), true); +} + +TEST_F(TestNodeValue, LogicalOperations) +{ + NodeValue t{true}; + NodeValue f{false}; + + NodeValue and_result = (t && t); + ASSERT_TRUE(and_result.is_a<bool>()); + ASSERT_EQ(and_result.get<bool>(), true); + + NodeValue and_result2 = (t && f); + ASSERT_TRUE(and_result2.is_a<bool>()); + ASSERT_EQ(and_result2.get<bool>(), false); + + NodeValue or_result = (t || f); + ASSERT_TRUE(or_result.is_a<bool>()); + ASSERT_EQ(or_result.get<bool>(), true); + + NodeValue or_result2 = (f || f); + ASSERT_TRUE(or_result2.is_a<bool>()); + ASSERT_EQ(or_result2.get<bool>(), false); +} + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} -- GitLab