// Flattened git unified diff, first span. Line breaks inside the hunks were
// collapsed to spaces, so the '+', '-' and '@@' markers appear inline.
// Sections contained in this span:
//  * examples/fib.funk - the three `fib` definitions now end their `return`
//    statements with ';'; the `main` loop driver is removed and replaced by the
//    single statement `20 >> fib >> print;`.
//  * include/ast/BlockNode.h - adds the ast/control/ReturnNode.h include.
//  * include/ast/control/ReturnNode.h (new) - ReturnNode : ControlNode, holding
//    an ExpressionNode* value, with evaluate(), to_s() and get_value().
//  * include/ast/declaration/FunctionNode.h (new) - FunctionNode : Node with two
//    constructors (typed parameter list vs. pattern values), call(), matches(),
//    accessors, and the private helper init_param_scope().
//  * include/ast/expression/CallNode.h - adds the FunctionNode/Registry/Scope
//    includes and the get_identifier() / get_args() accessors.
// NOTE(review): FunctionNode.h includes parser/Registry.h, and the Registry.h
// added later in this patch includes ast/declaration/FunctionNode.h back. This
// circular include appears to rely on #pragma once plus Registry.h's forward
// declaration of FunctionNode - confirm it compiles for every include order.
// NOTE(review): ReturnNode and FunctionNode hold raw pointers that their
// destructors (later in this patch) delete, and neither header declares copy
// operations - confirm whether the base classes already forbid copying.
diff --git a/examples/fib.funk b/examples/fib.funk index 4f6e6fc6ad5e9ed7c60b46a1a2985d78e0e4f595..357b69814c64f00469515fc755478f78bf7a54f1 100644 --- a/examples/fib.funk +++ b/examples/fib.funk @@ -1,13 +1,5 @@ -funk fib = (0) { return 0 }; -funk fib = (1) { return 1 }; -funk fib = (numb n) { return fib(n - 1) + fib(n - 2) }; +funk fib = (0) { return 0; }; +funk fib = (1) { return 1; }; +funk fib = (numb n) { return fib(n - 1) + fib(n - 2); }; -funk main = { - mut numb i = 0; - while (i < 100) { - fib(i) >> print; - i += 1; - } -}; - -main(); +20 >> fib >> print; diff --git a/include/ast/BlockNode.h b/include/ast/BlockNode.h index 7879f9038b9a7464818f1149128014c0c29f0eda..1fb8bbb77d6f0bcb2e2143302b7ba2b6728f36cd 100644 --- a/include/ast/BlockNode.h +++ b/include/ast/BlockNode.h @@ -1,6 +1,7 @@ #pragma once #include "ast/Node.h" +#include "ast/control/ReturnNode.h" #include "parser/Scope.h" namespace funk diff --git a/include/ast/control/ReturnNode.h b/include/ast/control/ReturnNode.h new file mode 100644 index 0000000000000000000000000000000000000000..71217a2fa24f46cc1d9613c650604d732c027ddf --- /dev/null +++ b/include/ast/control/ReturnNode.h @@ -0,0 +1,24 @@ +#pragma once + +#include "ast/control/ControlNode.h" +#include "ast/expression/ExpressionNode.h" + +namespace funk +{ + +class ReturnNode : public ControlNode +{ +public: + ReturnNode(const SourceLocation& location, ExpressionNode* value); + ~ReturnNode() override; + + Node* evaluate() const override; + String to_s() const override; + + ExpressionNode* get_value() const; + +private: + ExpressionNode* value; +}; + +} // namespace funk diff --git a/include/ast/declaration/FunctionNode.h b/include/ast/declaration/FunctionNode.h new file mode 100644 index 0000000000000000000000000000000000000000..023254ae270bfa6b7b9fe1d469edbea7d87bf490 --- /dev/null +++ b/include/ast/declaration/FunctionNode.h @@ -0,0 +1,47 @@ +#pragma once + +#include "ast/BlockNode.h" +#include "ast/declaration/VariableNode.h" +#include
"ast/expression/ExpressionNode.h" +#include "parser/Registry.h" +#include "parser/Scope.h" +#include "token/Token.h" + +namespace funk +{ + +class FunctionNode : public Node +{ +public: + FunctionNode(const SourceLocation& location, bool is_mutable, const String& identifier, + const Vector<Pair<TokenType, String>>& parameters, BlockNode* body); + + FunctionNode(const SourceLocation& location, bool is_mutable, const String& identifier, + const Vector<ExpressionNode*>& pattern_values, BlockNode* body); + + ~FunctionNode() override; + + Node* evaluate() const override; + String to_s() const override; + + Node* call(const Vector<ExpressionNode*>& arguments) const; + + bool is_mutable_function() const; + String get_identifier() const; + const Vector<Pair<TokenType, String>>& get_parameters() const; + BlockNode* get_body() const; + + bool is_pattern_matching() const; + bool matches(const Vector<ExpressionNode*>& arguments) const; + +private: + bool is_mutable; + bool is_pattern; + String identifier; + Vector<Pair<TokenType, String>> parameters; + Vector<ExpressionNode*> pattern_values; + BlockNode* body; + + void init_param_scope(const Vector<ExpressionNode*>& arguments) const; +}; +} // namespace funk diff --git a/include/ast/expression/CallNode.h b/include/ast/expression/CallNode.h index 0cf21d0c3440c209c4290510e31454c112f78722..5ff434082a6ad62ee7b9771c65681b7a24580181 100644 --- a/include/ast/expression/CallNode.h +++ b/include/ast/expression/CallNode.h @@ -1,9 +1,12 @@ #pragma once +#include "ast/declaration/FunctionNode.h" #include "ast/expression/ExpressionNode.h" #include "ast/expression/LiteralNode.h" #include "logging/LogMacros.h" #include "parser/BuiltIn.h" +#include "parser/Registry.h" +#include "parser/Scope.h" #include "token/Token.h" namespace funk @@ -19,6 +22,8 @@ public: String to_s() const override; NodeValue get_value() const override; + const Token& get_identifier() const; + const Vector<ExpressionNode*>& get_args() const; protected: Token identifier;
// Flattened git unified diff, second span (same collapsed-whitespace form as
// the sections before it). Sections contained in this span:
//  * include/ast/expression/PipeNode.h (new) - PipeNode : ExpressionNode with a
//    source and a target expression.
//  * include/parser/Parser.h - adds the PipeNode include and declares
//    check_next(), parse_variable_declaration(), parse_function_declaration(),
//    parse_return() and parse_pipe().
//  * include/parser/Registry.h (new) - singleton Registry mapping a function
//    name to a Vector<FunctionNode*> of definitions.
//  * include/parser/Scope.h, source/parser/Scope.cc - add Scope::contains().
//  * include/utils/Common.h - adds the Pair alias for std::pair.
//  * source/ast/BlockNode.cc - evaluate() now stops after a statement that is
//    itself a ReturnNode and resets `result` to nullptr after any other one.
//  * source/ast/control/ReturnNode.cc (new), source/ast/declaration/FunctionNode.cc
//    (new), source/ast/expression/PipeNode.cc (new), source/parser/Registry.cc (new).
//  * source/ast/declaration/VariableNode.cc - to_s() drops the "Variable: " prefix.
//  * source/ast/expression/CallNode.cc - evaluate() asks the Registry first and
//    the built-in table second; the scope lookup is left commented out.
//  * source/parser/Parser.cc - implements the new parse functions, routes
//    parse_assignment() through parse_pipe(), handles RETURN in parse_control(),
//    takes the block location from the token before '{', and inserts blank
//    lines after the LOG_DEBUG calls.
// NOTE(review): Parser::check_next() only guards with done() before calling
// tokens.at(index + 1); if done() is false on the last token, at() throws
// std::out_of_range - confirm how done() is defined.
// NOTE(review): BlockNode::evaluate() breaks only when the statement itself is a
// ReturnNode; a return nested inside another statement presumably ends just the
// inner block - verify against the if/while node implementations.
// NOTE(review): PipeNode::evaluate() allocates a CallNode with `new` and never
// deletes it. The new node's argument vector also holds the original call's
// argument pointers, so deleting it may double-free if CallNode's destructor
// (not shown) deletes its args - confirm ownership before fixing the leak.
// NOTE(review): PipeNode::evaluate() casts `target` (an ExpressionNode*) to
// FunctionNode*, but FunctionNode derives from Node in this patch; that branch
// looks unreachable - confirm.
// NOTE(review): PipeNode::get_value() returns target->get_value() and never
// touches `source`; looks like the piped value is bypassed - confirm intent.
// NOTE(review): parse_variable_declaration(), parse_pipe() and the pattern loop
// in parse_function_declaration() store dynamic_cast results without a null check.
// NOTE(review): parse_return() reads the ReturnNode location from peek_prev()
// after ';' was matched, so it presumably points at the semicolon - confirm.
// NOTE(review): FunctionNode::evaluate() calls Registry::add_function(), which
// push_backs unconditionally (see its TODO), so evaluating the same node twice
// registers it twice.
// NOTE(review): FunctionNode::matches() compares get_value() of the raw argument
// expressions, and FunctionNode::call() pushes a scope in addition to the one
// BlockNode::evaluate() pushes - confirm both are intended.
// NOTE(review): Scope::contains() stores scopes.size() - 1 in an int.
diff --git a/include/ast/expression/PipeNode.h b/include/ast/expression/PipeNode.h new file mode 100644 index 0000000000000000000000000000000000000000..6fdc0a8fedc1829685dfd2dbefc53c12cfdc99b8 --- /dev/null +++ b/include/ast/expression/PipeNode.h @@ -0,0 +1,27 @@ +#pragma once + +#include "ast/declaration/FunctionNode.h" +#include "ast/expression/CallNode.h" +#include "ast/expression/ExpressionNode.h" + +namespace funk +{ + +class PipeNode : public ExpressionNode +{ +public: + PipeNode(const SourceLocation& location, ExpressionNode* source, ExpressionNode* target); + ~PipeNode() override; + + Node* evaluate() const override; + String to_s() const override; + NodeValue get_value() const override; + + ExpressionNode* get_source() const; + ExpressionNode* get_target() const; + +private: + ExpressionNode* source; + ExpressionNode* target; +}; +} // namespace funk diff --git a/include/parser/Parser.h b/include/parser/Parser.h index bb44e6ea3a8c878ba0837a852773f74881240c57..f7231372eefb023800cc1fcc165668443700d22f 100644 --- a/include/parser/Parser.h +++ b/include/parser/Parser.h @@ -17,6 +17,7 @@ #include "ast/expression/ListNode.h" #include "ast/expression/LiteralNode.h" #include "ast/expression/MethodCallNode.h" +#include "ast/expression/PipeNode.h" #include "ast/expression/UnaryOpNode.h" #include "lexer/Lexer.h" #include "logging/LogMacros.h" @@ -109,6 +110,13 @@ private: */ bool check(TokenType type) const; + /** + * @brief Checks if the next token is of the expected type + * @param type The token type to check for + * @return bool True if the next token matches the expected type + */ + bool check_next(TokenType type) const; + /** * @brief Checks if the current token matches the expected type, and advances if it does * @param expected The token type to check for @@ -128,6 +136,18 @@ private: */ Node* parse_declaration(); + /** + * @brief Parses a variable declaration + * @return Node* The AST node representing the variable declaration + */ + Node*
parse_variable_declaration(bool is_mutable); + + /** + * @brief Parses a function declaration + * @return Node* The AST node representing the function declaration + */ + Node* parse_function_declaration(bool is_mutable); + /** * @brief Parses a block of statements * @return Node* The AST node representing the block @@ -152,6 +172,12 @@ private: */ Node* parse_while(); + /** + * @brief Parses a return statement + * @return Node* The AST node representing the return statement + */ + Node* parse_return(); + /** * @brief Parses an expression * @return Node* The AST node representing the expression @@ -164,6 +190,12 @@ private: */ Node* parse_assignment(); + /** + * @brief Parses a pipe expression + * @return Node* The AST node representing the pipe expression + */ + Node* parse_pipe(); + /** * @brief Parses a logical OR expression * @return Node* The AST node representing the logical OR expression diff --git a/include/parser/Registry.h b/include/parser/Registry.h new file mode 100644 index 0000000000000000000000000000000000000000..dc8d4465c1fe9b32914e0ec6a87809b11aff61cb --- /dev/null +++ b/include/parser/Registry.h @@ -0,0 +1,30 @@ +#pragma once + +#include "ast/declaration/FunctionNode.h" + +namespace funk +{ + +/** + * @brief Forward declaration of FunctionNode.
+ */ +class FunctionNode; + +class Registry +{ +public: + static Registry& instance(); + + bool add_function(FunctionNode* node); + FunctionNode* get_function(const String& identifier, const Vector<ExpressionNode*>& arguments) const; + void remove_function(const String& identifier); + bool contains(const String& identifier) const; + +private: + Registry() = default; + ~Registry() = default; + + HashMap<String, Vector<FunctionNode*>> functions; +}; + +} // namespace funk diff --git a/include/parser/Scope.h b/include/parser/Scope.h index 91ab262ab536a5b5db9b299c361428c2945511dc..b02738e9ad58fc32e932ca48911e88de97953c1e 100644 --- a/include/parser/Scope.h +++ b/include/parser/Scope.h @@ -17,6 +17,7 @@ public: void add(const String& name, Node* node); Node* get(const String& name) const; + bool contains(const String& name) const; private: Scope(); diff --git a/include/utils/Common.h b/include/utils/Common.h index 2cfde8a494e978f3908fc09e932b40cf2233727a..6715fde1fdd33855dbbddc1f59225d0f5e52f538 100644 --- a/include/utils/Common.h +++ b/include/utils/Common.h @@ -49,6 +49,13 @@ template <typename T> using Vector = std::vector<T>; */ template <typename K, typename V> using HashMap = std::unordered_map<K, V>; +/** + * @brief Template alias for std::pair. + * @tparam K The key type + * @tparam V The value type + */ +template <typename K, typename V> using Pair = std::pair<K, V>; + /** * @brief Macro for std::to_string.
*/ diff --git a/source/ast/BlockNode.cc b/source/ast/BlockNode.cc index 8fd7357d78b73393af4cfdc26077720b216dfb10..e945c01af71bdd5de88f1ff7cd7734ae42c971cd 100644 --- a/source/ast/BlockNode.cc +++ b/source/ast/BlockNode.cc @@ -13,7 +13,7 @@ BlockNode::~BlockNode() void BlockNode::add(Node* statement) { - if (statement == nullptr) + if (!statement) { LOG_WARN("Attempted to add null statement to block"); return; @@ -30,7 +30,12 @@ Node* BlockNode::evaluate() const { Scope::instance().push(); Node* result{}; - for (Node* statement : statements) { result = statement->evaluate(); } + for (Node* statement : statements) + { + result = statement->evaluate(); + if (dynamic_cast<ReturnNode*>(statement)) { break; } + result = nullptr; + } Scope::instance().pop(); return result; } diff --git a/source/ast/control/ReturnNode.cc b/source/ast/control/ReturnNode.cc new file mode 100644 index 0000000000000000000000000000000000000000..67825cb76c3ba372ae4a64da7eeb9e9bd10d0581 --- /dev/null +++ b/source/ast/control/ReturnNode.cc @@ -0,0 +1,30 @@ +#include "ast/control/ReturnNode.h" +#include "utils/Common.h" + +namespace funk +{ + +ReturnNode::ReturnNode(const SourceLocation& location, ExpressionNode* value) : ControlNode(location), value(value) {} + +ReturnNode::~ReturnNode() +{ + delete value; +} + +Node* ReturnNode::evaluate() const +{ + return value ?
value->evaluate() : nullptr; +} + +String ReturnNode::to_s() const +{ + if (!value) { return "return"; } + return "return " + value->to_s(); +} + +ExpressionNode* ReturnNode::get_value() const +{ + return value; +} + +} // namespace funk diff --git a/source/ast/declaration/FunctionNode.cc b/source/ast/declaration/FunctionNode.cc new file mode 100644 index 0000000000000000000000000000000000000000..240c375907d261ca2d2d25091267b269db24d3d4 --- /dev/null +++ b/source/ast/declaration/FunctionNode.cc @@ -0,0 +1,155 @@ +#include "ast/declaration/FunctionNode.h" + +namespace funk +{ + +FunctionNode::FunctionNode(const SourceLocation& location, bool is_mutable, const String& identifier, + const Vector<Pair<TokenType, String>>& parameters, BlockNode* body) : + Node(location), is_mutable(is_mutable), is_pattern(false), identifier(identifier), parameters(parameters), + body(body) +{ +} + +FunctionNode::FunctionNode(const SourceLocation& location, bool is_mutable, const String& identifier, + const Vector<ExpressionNode*>& pattern_values, BlockNode* body) : + Node(location), is_mutable(is_mutable), is_pattern(true), identifier(identifier), pattern_values(pattern_values), + body(body) +{ +} + +FunctionNode::~FunctionNode() +{ + delete body; + if (is_pattern) + { + for (ExpressionNode* pattern : pattern_values) { delete pattern; } + } +} + +Node* FunctionNode::evaluate() const +{ + // Register the function in the registry + Registry::instance().add_function(const_cast<FunctionNode*>(this)); + // Add the function to the current scope + Scope::instance().add(identifier, const_cast<FunctionNode*>(this)); + + return const_cast<FunctionNode*>(this); +} + +String FunctionNode::to_s() const +{ + String repr{(is_mutable ?
"mut " : "") + String("funk ") + identifier + " = ("}; + + if (is_pattern) + { + for (size_t i{0}; i < pattern_values.size(); i++) + { + repr += pattern_values[i]->to_s(); + if (i < pattern_values.size() - 1) { repr += ", "; } + } + } + else + { + for (size_t i{0}; i < parameters.size(); i++) + { + repr += token_type_to_s(parameters[i].first) + " " + parameters[i].second; + if (i < parameters.size() - 1) { repr += ", "; } + } + } + + repr += ") {\n" + body->to_s() + "}"; + return repr; +} + +Node* FunctionNode::call(const Vector<ExpressionNode*>& arguments) const +{ + // Push new scope + Scope::instance().push(); + try + { + // Add parameters to current scope + init_param_scope(arguments); + // Evaluate body + Node* result{body->evaluate()}; + // Pop scope + Scope::instance().pop(); + return result; + } + catch (...) + { + // Pop scope even if an exception was thrown + Scope::instance().pop(); + throw; + } +} + +bool FunctionNode::is_mutable_function() const +{ + return is_mutable; +} + +String FunctionNode::get_identifier() const +{ + return identifier; +} + +const Vector<Pair<TokenType, String>>& FunctionNode::get_parameters() const +{ + return parameters; +} + +BlockNode* FunctionNode::get_body() const +{ + return body; +} + +bool FunctionNode::is_pattern_matching() const +{ + return is_pattern; +} + +bool FunctionNode::matches(const Vector<ExpressionNode*>& arguments) const +{ + // Only match if it's a pattern matching function and the number of arguments matches the number of pattern values + if (!is_pattern || arguments.size() != pattern_values.size()) { return false; } + + for (size_t i{0}; i < arguments.size(); i++) + { + NodeValue pattern_value{pattern_values[i]->get_value()}; + NodeValue argument_value{arguments[i]->get_value()}; + // Check if the pattern value and argument value match + if ((pattern_value != argument_value).cast<bool>()) { return false; } + } + + return true; +} + +void FunctionNode::init_param_scope(const Vector<ExpressionNode*>&
arguments) const +{ + // For pattern matching functions, don't add parameters to the scope + // Pattern is already checked in matches() + if (is_pattern) { return; } + + size_t p_count{parameters.size()}; + size_t a_count{arguments.size()}; + + // Check if the number of arguments matches the number of parameters + if (a_count != p_count) + { + throw RuntimeError(location, + "Function '" + identifier + "' expects " + to_str(p_count) + " arguments, but got " + to_str(a_count)); + } + + for (size_t i{0}; i < p_count; i++) + { + // Evaluate argument + ExpressionNode* expr{dynamic_cast<ExpressionNode*>(arguments[i]->evaluate())}; + // Check if the argument is an expression + if (!expr) { throw RuntimeError(location, "Argument " + to_str(i) + " did not evaluate to an expression"); } + // Add argument to scope + VariableNode* var{new VariableNode(location, parameters[i].second, false, parameters[i].first, expr)}; + Scope::instance().add(parameters[i].second, var); + } +} + +} // namespace funk diff --git a/source/ast/declaration/VariableNode.cc b/source/ast/declaration/VariableNode.cc index 498077196829ce45259464bcb337bc4d76ca8090..c4fbc23c2f1856246de52ca9536df0051a626f72 100644 --- a/source/ast/declaration/VariableNode.cc +++ b/source/ast/declaration/VariableNode.cc @@ -33,8 +33,8 @@ Node* VariableNode::evaluate() const String VariableNode::to_s() const { - if (value == nullptr) { return "Variable: " + identifier; } - return "Variable: " + identifier + " = " + value->to_s(); + if (value == nullptr) { return identifier; } + return value->to_s(); } NodeValue VariableNode::get_value() const diff --git a/source/ast/expression/CallNode.cc b/source/ast/expression/CallNode.cc index eaba781f56fbb78c565db40c5fe8d4fbab000710..c75899c08f00fd7014e217a28a4ad2049c6a4065 100644 --- a/source/ast/expression/CallNode.cc +++ b/source/ast/expression/CallNode.cc @@ -15,8 +15,31 @@ CallNode::~CallNode() Node* CallNode::evaluate() const { LOG_DEBUG("Evaluating call to " +
identifier.get_lexeme()); + + // Check the registry first for pattern matching and overloaded functions + FunctionNode* func{Registry::instance().get_function(identifier.get_lexeme(), args)}; + if (func) + { + LOG_DEBUG("Found function in registry: " + func->get_identifier()); + return func->call(args); + } + + // // Check the current scope next for regular functions + // func = dynamic_cast<FunctionNode*>(Scope::instance().get(identifier.get_lexeme())); + // if (func) + // { + // LOG_DEBUG("Found function in scope: " + func->get_identifier()); + // return func->call(args); + // } + + // Finally, check the built-in functions auto it = BuiltIn::functions.find(identifier.get_lexeme()); - if (it != BuiltIn::functions.end()) { return it->second(*this, args); } + if (it != BuiltIn::functions.end()) + { + LOG_DEBUG("Found built-in function: " + identifier.get_lexeme()); + return it->second(*this, args); + } + throw RuntimeError(location, "Unknown function: " + identifier.get_lexeme()); } @@ -41,4 +64,14 @@ NodeValue CallNode::get_value() const return result->get_value(); } +const Token& CallNode::get_identifier() const +{ + return identifier; +} + +const Vector<ExpressionNode*>& CallNode::get_args() const +{ + return args; +} + } // namespace funk diff --git a/source/ast/expression/PipeNode.cc b/source/ast/expression/PipeNode.cc new file mode 100644 index 0000000000000000000000000000000000000000..a0a319aa0c79fe189bbdb2631378abf7f27c7408 --- /dev/null +++ b/source/ast/expression/PipeNode.cc @@ -0,0 +1,65 @@ +#include "ast/expression/PipeNode.h" + +namespace funk +{ + +PipeNode::PipeNode(const SourceLocation& location, ExpressionNode* source, ExpressionNode* target) : + ExpressionNode(location), source(source), target(target) +{ +} + +PipeNode::~PipeNode() +{ + delete source; + delete target; +} + +Node* PipeNode::evaluate() const +{ + // Evaluate the source expression + ExpressionNode* current{dynamic_cast<ExpressionNode*>(source->evaluate())}; + if (!current) { throw
RuntimeError(location, "Pipe source did not evaluate to an expression"); } + + // Create a list of arguments for the target function, starting with the source expression + Vector<ExpressionNode*> args{current}; + + if (auto call = dynamic_cast<CallNode*>(target)) + { + // Get the function's name and original arguments + const String name = call->get_identifier().get_lexeme(); + const Vector<ExpressionNode*> call_args = call->get_args(); + + // Add the original call arguments after the piped value + args.insert(args.end(), call_args.begin(), call_args.end()); + + // Create a new call node with the updated arguments + CallNode* new_call{new CallNode(call->get_identifier(), args)}; + + // Evaluate the new call node + return new_call->evaluate(); + } + else if (auto func = dynamic_cast<FunctionNode*>(target)) { return func->call(args); } + else { throw RuntimeError(location, "Pipe target must be a function or function identifier"); } +} + +String PipeNode::to_s() const +{ + return source->to_s() + " >> " + target->to_s(); +} + +NodeValue PipeNode::get_value() const +{ + return target->get_value(); +} + +ExpressionNode* PipeNode::get_source() const +{ + return source; +} + +ExpressionNode* PipeNode::get_target() const +{ + return target; +} + +} // namespace funk diff --git a/source/parser/Parser.cc b/source/parser/Parser.cc index d9ab86205fa1ad72eb41e2ecf3a4ab7d9d754e4f..a32b8ecf677fb605f29aa276ae7311035df0d0bf 100644 --- a/source/parser/Parser.cc +++ b/source/parser/Parser.cc @@ -70,6 +70,12 @@ bool Parser::check(TokenType type) const return tokens.at(index).get_type() == type; } +bool Parser::check_next(TokenType type) const +{ + if (done()) return false; + return tokens.at(index + 1).get_type() == type; +} + bool Parser::match(TokenType expected) { if (!check(expected)) return false; @@ -81,12 +87,12 @@ Node* Parser::parse_statement() { LOG_DEBUG("Parse statement"); - if (check(TokenType::COMMENT) || check(TokenType::BLOCK_COMMENT)) + if (match(TokenType::COMMENT)
|| match(TokenType::BLOCK_COMMENT)) { LOG_INFO("Skipping comment"); - next(); return nullptr; } + // Empty statement, just a semicolon if (match(TokenType::SEMICOLON)) { return new LiteralNode(peek_prev().get_location(), NodeValue(None())); } @@ -114,41 +120,122 @@ Node* Parser::parse_declaration() { LOG_DEBUG("Parse declaration"); - bool mut{false}; - if (check(TokenType::MUT)) - { - mut = true; - next(); - } + bool is_mutable{false}; + if (match(TokenType::MUT)) { is_mutable = true; } + if (check(TokenType::NUMB_TYPE) || check(TokenType::REAL_TYPE) || check(TokenType::BOOL_TYPE) || check(TokenType::CHAR_TYPE) || check(TokenType::TEXT_TYPE)) { - Token type{next()}; - if (check(TokenType::IDENTIFIER)) - { - Token identifier{next()}; + return parse_variable_declaration(is_mutable); + } - if (match(TokenType::ASSIGN)) - { - // needs to be parse_statement otherwise whines about semicolon dunno if I do it correctly - Node* expr{parse_statement()}; - if (!expr) { throw SyntaxError(peek_prev().get_location(), "Expected expression after '='"); } + if (check(TokenType::FUNK)) { return parse_function_declaration(is_mutable); } + + return nullptr; +} + +Node* Parser::parse_variable_declaration(bool is_mutable) +{ + LOG_DEBUG("Parse variable declaration"); - ExpressionNode* expr_node = dynamic_cast<ExpressionNode*>(expr); + Token type{next()}; - return new DeclarationNode( - type.get_location(), mut, type.get_type(), identifier.get_lexeme(), expr_node); - } - return new DeclarationNode(type.get_location(), mut, type.get_type(), identifier.get_lexeme()); - } + if (!check(TokenType::IDENTIFIER)) { return nullptr; } + Token identifier{next()}; + + if (!match(TokenType::ASSIGN)) + { + return new DeclarationNode(type.get_location(), is_mutable, type.get_type(), identifier.get_lexeme()); } - return nullptr; + Node* expr{parse_statement()}; + if (!expr) { throw SyntaxError(peek_prev().get_location(), "Expected expression after '='"); } + + ExpressionNode*
expr_node{dynamic_cast<ExpressionNode*>(expr)}; + return new DeclarationNode(type.get_location(), is_mutable, type.get_type(), identifier.get_lexeme(), expr_node); +} + +Node* Parser::parse_function_declaration(bool is_mutable) +{ + LOG_DEBUG("Parse function declaration"); + + if (!match(TokenType::FUNK)) { return nullptr; } + if (!check(TokenType::IDENTIFIER)) { throw SyntaxError(peek().get_location(), "Expected function name"); } + + Token identifier{next()}; + + if (!match(TokenType::ASSIGN)) { throw SyntaxError(peek().get_location(), "Expected '=' after function name"); } + if (!match(TokenType::L_PAR)) { throw SyntaxError(peek().get_location(), "Expected '('"); } + + // Check if it's a pattern matching function + if (check(TokenType::NUMB) || check(TokenType::REAL) || check(TokenType::BOOL) || check(TokenType::CHAR) || + check(TokenType::TEXT)) + { + LOG_DEBUG("Parse pattern matching function"); + Vector<ExpressionNode*> pattern{}; + + // Collect pattern arguments + do { + pattern.push_back(dynamic_cast<LiteralNode*>(parse_literal())); + } while (match(TokenType::COMMA)); + + if (!match(TokenType::R_PAR)) { throw SyntaxError(peek().get_location(), "Expected ')' after pattern"); } + + // Parse function body + BlockNode* body{dynamic_cast<BlockNode*>(parse_block())}; + if (!body) { throw SyntaxError(peek().get_location(), "Expected function body"); } + + return new FunctionNode(identifier.get_location(), is_mutable, identifier.get_lexeme(), pattern, body); + } + else + { + LOG_DEBUG("Parse regular function"); + Vector<Pair<TokenType, String>> parameters{}; + + // Collect parameters + if (!check(TokenType::R_PAR)) + { + do { + TokenType type; + switch (peek().get_type()) + { + case TokenType::NUMB_TYPE: type = TokenType::NUMB_TYPE; break; + case TokenType::REAL_TYPE: type = TokenType::REAL_TYPE; break; + case TokenType::BOOL_TYPE: type = TokenType::BOOL_TYPE; break; + case TokenType::CHAR_TYPE: type = TokenType::CHAR_TYPE; break; + case TokenType::TEXT_TYPE: type
= TokenType::TEXT_TYPE; break; + default: throw SyntaxError(peek().get_location(), "Expected parameter type"); + } + + next(); + + // Parse parameter name + if (!check(TokenType::IDENTIFIER)) + { + throw SyntaxError(peek().get_location(), "Expected parameter name"); + } + + // Add parameter to list + parameters.push_back({type, next().get_lexeme()}); + + } while (match(TokenType::COMMA)); + } + + if (!match(TokenType::R_PAR)) { throw SyntaxError(peek().get_location(), "Expected ')' after parameters"); } + + // Parse function body + BlockNode* body{dynamic_cast<BlockNode*>(parse_block())}; + if (!body) { throw SyntaxError(peek().get_location(), "Expected function body"); } + + return new FunctionNode(identifier.get_location(), is_mutable, identifier.get_lexeme(), parameters, body); + } } Node* Parser::parse_block() { LOG_DEBUG("Parse block"); + + SourceLocation start{peek().get_location()}; if (!match(TokenType::L_BRACE)) { throw SyntaxError(peek().get_location(), "Expected '{'"); } Vector<Node*> statements{}; @@ -156,20 +243,23 @@ Node* Parser::parse_block() if (!match(TokenType::R_BRACE)) { throw SyntaxError(peek().get_location(), "Expected '}'"); } - return new BlockNode(statements.at(0)->get_location(), statements); + return new BlockNode(start, statements); } Node* Parser::parse_control() { LOG_DEBUG("Parse control flow"); + if (match(TokenType::IF)) { return parse_if(); } if (match(TokenType::WHILE)) { return parse_while(); } + if (match(TokenType::RETURN)) { return parse_return(); } return nullptr; } Node* Parser::parse_if() { LOG_DEBUG("Parse if"); + if (!match(TokenType::L_PAR)) { throw SyntaxError(peek().get_location(), "Expected '('"); } ExpressionNode* condition{dynamic_cast<ExpressionNode*>(parse_expression())}; @@ -188,6 +278,7 @@ Node* Parser::parse_if() Node* Parser::parse_while() { LOG_DEBUG("Parse while loop"); + if (!match(TokenType::L_PAR)) { throw SyntaxError(peek().get_location(), "Expected '('"); } ExpressionNode*
condition{dynamic_cast<ExpressionNode*>(parse_expression())}; if (!match(TokenType::R_PAR)) { throw SyntaxError(peek().get_location(), "Expected ')'"); } @@ -195,21 +286,72 @@ Node* Parser::parse_while() return new WhileNode(condition, body); } +Node* Parser::parse_return() +{ + LOG_DEBUG("Parse return"); + + if (match(TokenType::SEMICOLON)) { return new ReturnNode(peek_prev().get_location(), nullptr); } + ExpressionNode* value{dynamic_cast<ExpressionNode*>(parse_expression())}; + if (!value) { throw SyntaxError(peek().get_location(), "Expected expression"); } + if (!match(TokenType::SEMICOLON)) { throw SyntaxError(peek().get_location(), "Expected ';'"); } + return new ReturnNode(peek_prev().get_location(), value); +} + Node* Parser::parse_expression() { LOG_DEBUG("Parse expression"); + return parse_assignment(); } Node* Parser::parse_assignment() { LOG_DEBUG("Parse assignment"); - return parse_logical_or(); + + return parse_pipe(); +} + +Node* Parser::parse_pipe() +{ + LOG_DEBUG("Parse pipe"); + + Node* expr{parse_logical_or()}; + while (check(TokenType::PIPE)) + { + SourceLocation loc{next().get_location()}; + ExpressionNode* source{dynamic_cast<ExpressionNode*>(expr)}; + if (!source) { throw SyntaxError(peek().get_location(), "Left side of pipe must be an expression"); } + + if (!check(TokenType::IDENTIFIER)) + { + throw SyntaxError(peek().get_location(), "Expected function or function call after pipe operator"); + } + + Token identifier{next()}; + Vector<ExpressionNode*> args{}; + + if (match(TokenType::L_PAR)) + { + if (!check(TokenType::R_PAR)) + { + do { + args.push_back(dynamic_cast<ExpressionNode*>(parse_expression())); + } while (match(TokenType::COMMA)); + } + + if (!match(TokenType::R_PAR)) { throw SyntaxError(peek().get_location(), "Expected ')' after arguments"); } + } + + expr = new PipeNode(loc, source, new CallNode(identifier, args)); + } + + return expr; } Node* Parser::parse_logical_or() { LOG_DEBUG("Parse logical or"); + ExpressionNode*
left{dynamic_cast<ExpressionNode*>(parse_logical_and())}; while (match(TokenType::OR)) @@ -225,6 +367,7 @@ Node* Parser::parse_logical_or() Node* Parser::parse_logical_and() { LOG_DEBUG("Parse logical and"); + ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_equality())}; while (match(TokenType::AND)) @@ -240,6 +383,7 @@ Node* Parser::parse_logical_and() Node* Parser::parse_equality() { LOG_DEBUG("Parse equality"); + ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_comparison())}; while (match(TokenType::EQUAL) || match(TokenType::NOT_EQUAL)) @@ -255,6 +399,7 @@ Node* Parser::parse_equality() Node* Parser::parse_comparison() { LOG_DEBUG("Parse comparison"); + ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_additive())}; while (match(TokenType::LESS) || match(TokenType::LESS_EQUAL) || match(TokenType::GREATER) || @@ -271,6 +416,7 @@ Node* Parser::parse_comparison() Node* Parser::parse_additive() { LOG_DEBUG("Parse addative"); + ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_multiplicative())}; while (match(TokenType::PLUS) || match(TokenType::MINUS)) @@ -286,6 +432,7 @@ Node* Parser::parse_additive() Node* Parser::parse_multiplicative() { LOG_DEBUG("Parse multiplicative"); + ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_unary())}; while ( @@ -302,6 +449,7 @@ Node* Parser::parse_multiplicative() Node* Parser::parse_unary() { LOG_DEBUG("Parse unary"); + if (match(TokenType::MINUS) || match(TokenType::NOT)) { Token op{peek_prev()}; @@ -315,6 +463,7 @@ Node* Parser::parse_unary() Node* Parser::parse_factor() { LOG_DEBUG("Parse factor"); + Node* expr = nullptr; if (check(TokenType::IDENTIFIER)) { expr = parse_identifier(); } @@ -339,6 +488,7 @@ Node* Parser::parse_factor() Node* Parser::parse_literal() { LOG_DEBUG("Parse literal"); + Token literal{next()}; return new LiteralNode(literal.get_location(), NodeValue(literal.get_value())); } @@ -346,6 +496,7 @@ Node* Parser::parse_literal() Node* Parser::parse_identifier() {
LOG_DEBUG("Parse identifier"); + Token identifier{next()}; if (check(TokenType::L_PAR)) { return parse_call(identifier); } @@ -378,6 +529,7 @@ Node* Parser::parse_call(const Token& identifier) Node* Parser::parse_method_call(ExpressionNode* object) { LOG_DEBUG("Parse method call"); + if (!check(TokenType::IDENTIFIER)) { throw SyntaxError(peek().get_location(), "Expected method name after '.'"); } Token method{next()}; @@ -400,6 +552,7 @@ Node* Parser::parse_method_call(ExpressionNode* object) Node* Parser::parse_list() { LOG_DEBUG("Parse list"); + if (!match(TokenType::L_BRACKET)) { throw SyntaxError(peek().get_location(), "Expected '['"); } Vector<ExpressionNode*> elements{}; diff --git a/source/parser/Registry.cc b/source/parser/Registry.cc new file mode 100644 index 0000000000000000000000000000000000000000..7bfc6c03680698cc4f99f845c428176f7c2bc052 --- /dev/null +++ b/source/parser/Registry.cc @@ -0,0 +1,52 @@ +#include "parser/Registry.h" + +namespace funk +{ + +Registry& Registry::instance() +{ + static Registry registry; + return registry; +} + +bool Registry::add_function(FunctionNode* function) +{ + // TODO: Check for duplicate functions for the same identifier and arguments + functions[function->get_identifier()].push_back(function); + return true; +} + +FunctionNode* Registry::get_function(const String& identifier, const Vector<ExpressionNode*>& arguments) const +{ + // Check if the function exists + if (functions.find(identifier) == functions.end()) { return nullptr; } + + // Check if the function is a pattern matching function + for (FunctionNode* function : functions.at(identifier)) + { + if (function->is_pattern_matching() && function->matches(arguments)) { return function; } + } + + // Check if the function is a regular function + for (FunctionNode* function : functions.at(identifier)) + { + if (!function->is_pattern_matching() && function->get_parameters().size() == arguments.size()) + { + return function; + } + } + + return nullptr; +} + +void
Registry::remove_function(const String& identifier) +{ + functions.erase(identifier); +} + +bool Registry::contains(const String& identifier) const +{ + return functions.find(identifier) != functions.end(); +} + +} // namespace funk diff --git a/source/parser/Scope.cc b/source/parser/Scope.cc index bd38505059cbfe6efdd1afcf54cbde9b89e67efa..8299717505d5c833cff09776c7c8e63990d1c3e7 100644 --- a/source/parser/Scope.cc +++ b/source/parser/Scope.cc @@ -47,4 +47,13 @@ Node* Scope::get(const String& name) const return nullptr; } +bool Scope::contains(const String& name) const +{ + for (int i = scopes.size() - 1; i >= 0; i--) + { + if (scopes[i].find(name) != scopes[i].end()) { return true; } + } + return false; +} + } // namespace funk