Skip to content
Snippets Groups Projects
Commit 7abc93c2 authored by Mattias Ajander's avatar Mattias Ajander
Browse files

Implement functions, return statements, and pipe operator along with a number of other fixes.

parent 5b528ecb
No related branches found
No related tags found
No related merge requests found
Showing
with 711 additions and 43 deletions
funk fib = (0) { return 0 }; funk fib = (0) { return 0; };
funk fib = (1) { return 1 }; funk fib = (1) { return 1; };
funk fib = (numb n) { return fib(n - 1) + fib(n - 2) }; funk fib = (numb n) { return fib(n - 1) + fib(n - 2); };
funk main = { 20 >> fib >> print;
mut numb i = 0;
while (i < 100) {
fib(i) >> print;
i += 1;
}
};
main();
#pragma once #pragma once
#include "ast/Node.h" #include "ast/Node.h"
#include "ast/control/ReturnNode.h"
#include "parser/Scope.h" #include "parser/Scope.h"
namespace funk namespace funk
......
#pragma once
#include "ast/control/ControlNode.h"
#include "ast/expression/ExpressionNode.h"
namespace funk
{
class ReturnNode : public ControlNode
{
public:
ReturnNode(const SourceLocation& location, ExpressionNode* value);
~ReturnNode() override;
Node* evaluate() const override;
String to_s() const override;
ExpressionNode* get_value() const;
private:
ExpressionNode* value;
};
} // namespace funk
#pragma once
#include "ast/BlockNode.h"
#include "ast/declaration/VariableNode.h"
#include "ast/expression/ExpressionNode.h"
#include "parser/Registry.h"
#include "parser/Scope.h"
#include "token/Token.h"
namespace funk
{
class FunctionNode : public Node
{
public:
FunctionNode(const SourceLocation& location, bool is_mutable, const String& identifier,
const Vector<Pair<TokenType, String>>& parameters, BlockNode* body);
FunctionNode(const SourceLocation& location, bool is_mutable, const String& identifier,
const Vector<ExpressionNode*>& pattern_values, BlockNode* body);
~FunctionNode() override;
Node* evaluate() const override;
String to_s() const override;
Node* call(const Vector<ExpressionNode*>& arguments) const;
bool is_mutable_function() const;
String get_identifier() const;
const Vector<Pair<TokenType, String>>& get_parameters() const;
BlockNode* get_body() const;
bool is_pattern_matching() const;
bool matches(const Vector<ExpressionNode*>& arguments) const;
private:
bool is_mutable;
bool is_pattern;
String identifier;
Vector<Pair<TokenType, String>> parameters;
Vector<ExpressionNode*> pattern_values;
BlockNode* body;
void init_param_scope(const Vector<ExpressionNode*>& arguments) const;
};
} // namespace funk
#pragma once #pragma once
#include "ast/declaration/FunctionNode.h"
#include "ast/expression/ExpressionNode.h" #include "ast/expression/ExpressionNode.h"
#include "ast/expression/LiteralNode.h" #include "ast/expression/LiteralNode.h"
#include "logging/LogMacros.h" #include "logging/LogMacros.h"
#include "parser/BuiltIn.h" #include "parser/BuiltIn.h"
#include "parser/Registry.h"
#include "parser/Scope.h"
#include "token/Token.h" #include "token/Token.h"
namespace funk namespace funk
...@@ -19,6 +22,8 @@ public: ...@@ -19,6 +22,8 @@ public:
String to_s() const override; String to_s() const override;
NodeValue get_value() const override; NodeValue get_value() const override;
const Token& get_identifier() const;
const Vector<ExpressionNode*>& get_args() const;
protected: protected:
Token identifier; Token identifier;
......
#pragma once
#include "ast/declaration/FunctionNode.h"
#include "ast/expression/CallNode.h"
#include "ast/expression/ExpressionNode.h"
namespace funk
{
class PipeNode : public ExpressionNode
{
public:
PipeNode(const SourceLocation& location, ExpressionNode* source, ExpressionNode* target);
~PipeNode() override;
Node* evaluate() const override;
String to_s() const override;
NodeValue get_value() const override;
ExpressionNode* get_source() const;
ExpressionNode* get_target() const;
private:
ExpressionNode* source;
ExpressionNode* target;
};
} // namespace funk
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "ast/expression/ListNode.h" #include "ast/expression/ListNode.h"
#include "ast/expression/LiteralNode.h" #include "ast/expression/LiteralNode.h"
#include "ast/expression/MethodCallNode.h" #include "ast/expression/MethodCallNode.h"
#include "ast/expression/PipeNode.h"
#include "ast/expression/UnaryOpNode.h" #include "ast/expression/UnaryOpNode.h"
#include "lexer/Lexer.h" #include "lexer/Lexer.h"
#include "logging/LogMacros.h" #include "logging/LogMacros.h"
...@@ -109,6 +110,13 @@ private: ...@@ -109,6 +110,13 @@ private:
*/ */
bool check(TokenType type) const; bool check(TokenType type) const;
/**
* @brief Checks if the next token is of the expected type
* @param type The token type to check for
* @return bool True if the next token matches the expected type
*/
bool check_next(TokenType type) const;
/** /**
* @brief Checks if the current token matches the expected type, and advances if it does * @brief Checks if the current token matches the expected type, and advances if it does
* @param expected The token type to check for * @param expected The token type to check for
...@@ -128,6 +136,18 @@ private: ...@@ -128,6 +136,18 @@ private:
*/ */
Node* parse_declaration(); Node* parse_declaration();
/**
* @brief Parses a variable declaration
* @return Node* The AST node representing the variable declaration
*/
Node* parse_variable_declaration(bool is_mutable);
/**
* @brief Parses a function declaration
* @return Node* The AST node representing the function declaration
*/
Node* parse_function_declaration(bool is_mutable);
/** /**
* @brief Parses a block of statements * @brief Parses a block of statements
* @return Node* The AST node representing the block * @return Node* The AST node representing the block
...@@ -152,6 +172,12 @@ private: ...@@ -152,6 +172,12 @@ private:
*/ */
Node* parse_while(); Node* parse_while();
/**
* @brief Parses a return statement
* @return Node* The AST node representing the return statement
*/
Node* parse_return();
/** /**
* @brief Parses an expression * @brief Parses an expression
* @return Node* The AST node representing the expression * @return Node* The AST node representing the expression
...@@ -164,6 +190,12 @@ private: ...@@ -164,6 +190,12 @@ private:
*/ */
Node* parse_assignment(); Node* parse_assignment();
/**
* @brief Parses a pipe expression
* @return Node* The AST node representing the pipe expression
*/
Node* parse_pipe();
/** /**
* @brief Parses a logical OR expression * @brief Parses a logical OR expression
* @return Node* The AST node representing the logical OR expression * @return Node* The AST node representing the logical OR expression
......
#pragma once
#include "ast/declaration/FunctionNode.h"
namespace funk
{
/**
* @brief Forward declaration of FunctionNode.
*/
class FunctionNode;
class Registry
{
public:
static Registry& instance();
bool add_function(FunctionNode* node);
FunctionNode* get_function(const String& identifier, const Vector<ExpressionNode*>& arguments) const;
void remove_function(const String& identifier);
bool contains(const String& identifier) const;
private:
Registry() = default;
~Registry() = default;
HashMap<String, Vector<FunctionNode*>> functions;
};
} // namespace funk
...@@ -17,6 +17,7 @@ public: ...@@ -17,6 +17,7 @@ public:
void add(const String& name, Node* node); void add(const String& name, Node* node);
Node* get(const String& name) const; Node* get(const String& name) const;
bool contains(const String& name) const;
private: private:
Scope(); Scope();
......
...@@ -49,6 +49,13 @@ template <typename T> using Vector = std::vector<T>; ...@@ -49,6 +49,13 @@ template <typename T> using Vector = std::vector<T>;
*/ */
template <typename K, typename V> using HashMap = std::unordered_map<K, V>; template <typename K, typename V> using HashMap = std::unordered_map<K, V>;
/**
* @brief Template alias for std::pair.
* @tparam K The key type
* @tparam V The value type
*/
template <typename K, typename V> using Pair = std::pair<K, V>;
/** /**
* @brief Macro for std::to_string. * @brief Macro for std::to_string.
*/ */
......
...@@ -13,7 +13,7 @@ BlockNode::~BlockNode() ...@@ -13,7 +13,7 @@ BlockNode::~BlockNode()
void BlockNode::add(Node* statement) void BlockNode::add(Node* statement)
{ {
if (statement == nullptr) if (!statement)
{ {
LOG_WARN("Attempted to add null statement to block"); LOG_WARN("Attempted to add null statement to block");
return; return;
...@@ -30,7 +30,12 @@ Node* BlockNode::evaluate() const ...@@ -30,7 +30,12 @@ Node* BlockNode::evaluate() const
{ {
Scope::instance().push(); Scope::instance().push();
Node* result{}; Node* result{};
for (Node* statement : statements) { result = statement->evaluate(); } for (Node* statement : statements)
{
result = statement->evaluate();
if (dynamic_cast<ReturnNode*>(statement)) { break; }
result = nullptr;
}
Scope::instance().pop(); Scope::instance().pop();
return result; return result;
} }
......
#include "ast/control/ReturnNode.h"
#include "utils/Common.h"
namespace funk
{
ReturnNode::ReturnNode(const SourceLocation& location, ExpressionNode* value) : ControlNode(location), value(value) {}
ReturnNode::~ReturnNode()
{
delete value;
}
Node* ReturnNode::evaluate() const
{
return value ? value->evaluate() : nullptr;
}
String ReturnNode::to_s() const
{
if (!value) { return "return"; }
return "return " + value->to_s();
}
ExpressionNode* ReturnNode::get_value() const
{
return value;
}
} // namespace funk
#include "ast/declaration/FunctionNode.h"
namespace funk
{
FunctionNode::FunctionNode(const SourceLocation& location, bool is_mutable, const String& identifier,
const Vector<Pair<TokenType, String>>& parameters, BlockNode* body) :
Node(location), is_mutable(is_mutable), is_pattern(false), identifier(identifier), parameters(parameters),
body(body)
{
}
FunctionNode::FunctionNode(const SourceLocation& location, bool is_mutable, const String& identifier,
const Vector<ExpressionNode*>& pattern_values, BlockNode* body) :
Node(location), is_mutable(is_mutable), is_pattern(true), identifier(identifier), pattern_values(pattern_values),
body(body)
{
}
FunctionNode::~FunctionNode()
{
delete body;
if (is_pattern)
{
for (ExpressionNode* pattern : pattern_values) { delete pattern; }
}
}
Node* FunctionNode::evaluate() const
{
// Register the function in the registry
Registry::instance().add_function(const_cast<FunctionNode*>(this));
// Add the function to the current scope
Scope::instance().add(identifier, const_cast<FunctionNode*>(this));
return const_cast<FunctionNode*>(this);
}
String FunctionNode::to_s() const
{
String repr{(is_mutable ? "mut " : "") + String("funk ") + identifier + " = ("};
if (is_pattern)
{
for (size_t i{0}; i < pattern_values.size(); i++)
{
repr += pattern_values[i]->to_s();
if (i < pattern_values.size() - 1) { repr += ", "; }
}
}
else
{
for (size_t i{0}; i < parameters.size(); i++)
{
repr += token_type_to_s(parameters[i].first) + " " + parameters[i].second;
if (i < parameters.size() - 1) { repr += ", "; }
}
}
repr += ") {\n" + body->to_s() + "}";
return repr;
}
Node* FunctionNode::call(const Vector<ExpressionNode*>& arguments) const
{
// Push new scope
Scope::instance().push();
try
{
// Add parameters to current scope
init_param_scope(arguments);
// Evaluate body
Node* result{body->evaluate()};
// Pop scope
Scope::instance().pop();
return result;
}
catch (...)
{
// Pop scope even if an exception was thrown
Scope::instance().pop();
throw;
}
}
bool FunctionNode::is_mutable_function() const
{
return is_mutable;
}
String FunctionNode::get_identifier() const
{
return identifier;
}
const Vector<Pair<TokenType, String>>& FunctionNode::get_parameters() const
{
return parameters;
}
BlockNode* FunctionNode::get_body() const
{
return body;
}
bool FunctionNode::is_pattern_matching() const
{
return is_pattern;
}
bool FunctionNode::matches(const Vector<ExpressionNode*>& arguments) const
{
// Only match if it's a pattern matching function and the number of arguments matches the number of pattern values
if (!is_pattern || arguments.size() != pattern_values.size()) { return false; }
for (size_t i{0}; i < arguments.size(); i++)
{
NodeValue pattern_value{pattern_values[i]->get_value()};
NodeValue argument_value{arguments[i]->get_value()};
// Check if the pattern value and argument value match
if ((pattern_value != argument_value).cast<bool>()) { return false; }
}
return true;
}
void FunctionNode::init_param_scope(const Vector<ExpressionNode*>& arguments) const
{
// For pattern matching functions, don't add parameters to the scope
// Pattern is already checked in matches()
if (is_pattern) { return; }
size_t p_count{parameters.size()};
size_t a_count{arguments.size()};
// Check if the number of arguments matches the number of parameters
if (a_count != p_count)
{
throw RuntimeError(location,
"Function '" + identifier + "' expects " + to_str(p_count) + " arguments, but got " + to_str(a_count));
}
for (size_t i{0}; i < p_count; i++)
{
// Evaluate argument
ExpressionNode* expr{dynamic_cast<ExpressionNode*>(arguments[i]->evaluate())};
// Check if the argument is an expression
if (!expr) { throw RuntimeError(location, "Argument " + to_str(i) + " did not evaluate to an expression"); }
// Add argument to scope
VariableNode* var{new VariableNode(location, parameters[i].second, false, parameters[i].first, expr)};
Scope::instance().add(parameters[i].second, var);
}
}
} // namespace funk
...@@ -33,8 +33,8 @@ Node* VariableNode::evaluate() const ...@@ -33,8 +33,8 @@ Node* VariableNode::evaluate() const
String VariableNode::to_s() const String VariableNode::to_s() const
{ {
if (value == nullptr) { return "Variable: " + identifier; } if (value == nullptr) { return identifier; }
return "Variable: " + identifier + " = " + value->to_s(); return value->to_s();
} }
NodeValue VariableNode::get_value() const NodeValue VariableNode::get_value() const
......
...@@ -15,8 +15,31 @@ CallNode::~CallNode() ...@@ -15,8 +15,31 @@ CallNode::~CallNode()
Node* CallNode::evaluate() const Node* CallNode::evaluate() const
{ {
LOG_DEBUG("Evaluating call to " + identifier.get_lexeme()); LOG_DEBUG("Evaluating call to " + identifier.get_lexeme());
// Check the registry first for pattern matching and overloaded functions
FunctionNode* func{Registry::instance().get_function(identifier.get_lexeme(), args)};
if (func)
{
LOG_DEBUG("Found function in registry: " + func->get_identifier());
return func->call(args);
}
// // Check the current scope next for regular functions
// func = dynamic_cast<FunctionNode*>(Scope::instance().get(identifier.get_lexeme()));
// if (func)
// {
// LOG_DEBUG("Found function in scope: " + func->get_identifier());
// return func->call(args);
// }
// Finally, check the built-in functions
auto it = BuiltIn::functions.find(identifier.get_lexeme()); auto it = BuiltIn::functions.find(identifier.get_lexeme());
if (it != BuiltIn::functions.end()) { return it->second(*this, args); } if (it != BuiltIn::functions.end())
{
LOG_DEBUG("Found built-in function: " + identifier.get_lexeme());
return it->second(*this, args);
}
throw RuntimeError(location, "Unknown function: " + identifier.get_lexeme()); throw RuntimeError(location, "Unknown function: " + identifier.get_lexeme());
} }
...@@ -41,4 +64,14 @@ NodeValue CallNode::get_value() const ...@@ -41,4 +64,14 @@ NodeValue CallNode::get_value() const
return result->get_value(); return result->get_value();
} }
const Token& CallNode::get_identifier() const
{
return identifier;
}
const Vector<ExpressionNode*>& CallNode::get_args() const
{
return args;
}
} // namespace funk } // namespace funk
#include "ast/expression/PipeNode.h"
namespace funk
{
PipeNode::PipeNode(const SourceLocation& location, ExpressionNode* source, ExpressionNode* target) :
ExpressionNode(location), source(source), target(target)
{
}
PipeNode::~PipeNode()
{
delete source;
delete target;
}
Node* PipeNode::evaluate() const
{
// Evaluate the source expression
ExpressionNode* current{dynamic_cast<ExpressionNode*>(source->evaluate())};
if (!current) { throw RuntimeError(location, "Pipe source did not evaluate to an expression"); }
// Create a list of arguments for the target function, starting with the source expression
Vector<ExpressionNode*> args{current};
if (auto call = dynamic_cast<CallNode*>(target))
{
// Get the function's name and original arguments
const String name = call->get_identifier().get_lexeme();
const Vector<ExpressionNode*> call_args = call->get_args();
// Add the original call arguments after the piped value
args.insert(args.end(), call_args.begin(), call_args.end());
// Create a new call node with the updated arguments
CallNode* new_call{new CallNode(call->get_identifier(), args)};
// Evaluate the new call node
return new_call->evaluate();
}
else if (auto func = dynamic_cast<FunctionNode*>(target)) { return func->call(args); }
else { throw RuntimeError(location, "Pipe target must be a function or function identifier"); }
}
String PipeNode::to_s() const
{
return source->to_s() + " >> " + target->to_s();
}
NodeValue PipeNode::get_value() const
{
return target->get_value();
}
ExpressionNode* PipeNode::get_source() const
{
return source;
}
ExpressionNode* PipeNode::get_target() const
{
return target;
}
} // namespace funk
...@@ -70,6 +70,12 @@ bool Parser::check(TokenType type) const ...@@ -70,6 +70,12 @@ bool Parser::check(TokenType type) const
return tokens.at(index).get_type() == type; return tokens.at(index).get_type() == type;
} }
bool Parser::check_next(TokenType type) const
{
if (done()) return false;
return tokens.at(index + 1).get_type() == type;
}
bool Parser::match(TokenType expected) bool Parser::match(TokenType expected)
{ {
if (!check(expected)) return false; if (!check(expected)) return false;
...@@ -81,12 +87,12 @@ Node* Parser::parse_statement() ...@@ -81,12 +87,12 @@ Node* Parser::parse_statement()
{ {
LOG_DEBUG("Parse statement"); LOG_DEBUG("Parse statement");
if (check(TokenType::COMMENT) || check(TokenType::BLOCK_COMMENT)) if (match(TokenType::COMMENT) || match(TokenType::BLOCK_COMMENT))
{ {
LOG_INFO("Skipping comment"); LOG_INFO("Skipping comment");
next();
return nullptr; return nullptr;
} }
// Empty statement, just a semicolon // Empty statement, just a semicolon
if (match(TokenType::SEMICOLON)) { return new LiteralNode(peek_prev().get_location(), NodeValue(None())); } if (match(TokenType::SEMICOLON)) { return new LiteralNode(peek_prev().get_location(), NodeValue(None())); }
...@@ -114,41 +120,122 @@ Node* Parser::parse_declaration() ...@@ -114,41 +120,122 @@ Node* Parser::parse_declaration()
{ {
LOG_DEBUG("Parse declaration"); LOG_DEBUG("Parse declaration");
bool mut{false}; bool is_mutable{false};
if (check(TokenType::MUT)) if (match(TokenType::MUT)) { is_mutable = true; }
{
mut = true;
next();
}
if (check(TokenType::NUMB_TYPE) || check(TokenType::REAL_TYPE) || check(TokenType::BOOL_TYPE) || if (check(TokenType::NUMB_TYPE) || check(TokenType::REAL_TYPE) || check(TokenType::BOOL_TYPE) ||
check(TokenType::CHAR_TYPE) || check(TokenType::TEXT_TYPE)) check(TokenType::CHAR_TYPE) || check(TokenType::TEXT_TYPE))
{ {
Token type{next()}; return parse_variable_declaration(is_mutable);
if (check(TokenType::IDENTIFIER)) }
{
Token identifier{next()};
if (match(TokenType::ASSIGN)) if (check(TokenType::FUNK)) { return parse_function_declaration(is_mutable); }
{
// needs to be parse_statement otherwise whines about semicolon dunno if I do it correctly return nullptr;
Node* expr{parse_statement()}; }
if (!expr) { throw SyntaxError(peek_prev().get_location(), "Expected expression after '='"); }
Node* Parser::parse_variable_declaration(bool is_mutable)
{
LOG_DEBUG("Parse variable declaration");
ExpressionNode* expr_node = dynamic_cast<ExpressionNode*>(expr); Token type{next()};
return new DeclarationNode( if (!check(TokenType::IDENTIFIER)) { return nullptr; }
type.get_location(), mut, type.get_type(), identifier.get_lexeme(), expr_node); Token identifier{next()};
}
return new DeclarationNode(type.get_location(), mut, type.get_type(), identifier.get_lexeme()); if (!match(TokenType::ASSIGN))
} {
return new DeclarationNode(type.get_location(), is_mutable, type.get_type(), identifier.get_lexeme());
} }
return nullptr; Node* expr{parse_statement()};
if (!expr) { throw SyntaxError(peek_prev().get_location(), "Expected expression after '='"); }
ExpressionNode* expr_node{dynamic_cast<ExpressionNode*>(expr)};
return new DeclarationNode(type.get_location(), is_mutable, type.get_type(), identifier.get_lexeme(), expr_node);
}
Node* Parser::parse_function_declaration(bool is_mutable)
{
LOG_DEBUG("Parse function declaration");
if (!match(TokenType::FUNK)) { return nullptr; }
if (!check(TokenType::IDENTIFIER)) { throw SyntaxError(peek().get_location(), "Expected function name"); }
Token identifier{next()};
if (!match(TokenType::ASSIGN)) { throw SyntaxError(peek().get_location(), "Expected '=' after function name"); }
if (!match(TokenType::L_PAR)) { throw SyntaxError(peek().get_location(), "Expected '('"); }
// Check if it's a pattern matching function
if (check(TokenType::NUMB) || check(TokenType::REAL) || check(TokenType::BOOL) || check(TokenType::CHAR) ||
check(TokenType::TEXT))
{
LOG_DEBUG("Parse pattern matching function");
Vector<ExpressionNode*> pattern{};
// Collect pattern arguments
do {
pattern.push_back(dynamic_cast<LiteralNode*>(parse_literal()));
} while (match(TokenType::COMMA));
if (!match(TokenType::R_PAR)) { throw SyntaxError(peek().get_location(), "Expected ')' after pattern"); }
// Parse function body
BlockNode* body{dynamic_cast<BlockNode*>(parse_block())};
if (!body) { throw SyntaxError(peek().get_location(), "Expected function body"); }
return new FunctionNode(identifier.get_location(), is_mutable, identifier.get_lexeme(), pattern, body);
}
else
{
LOG_DEBUG("Parse regular function");
Vector<Pair<TokenType, String>> parameters{};
// Collect parameters
if (!check(TokenType::R_PAR))
{
do {
TokenType type;
switch (peek().get_type())
{
case TokenType::NUMB_TYPE: type = TokenType::NUMB_TYPE; break;
case TokenType::REAL_TYPE: type = TokenType::REAL_TYPE; break;
case TokenType::BOOL_TYPE: type = TokenType::BOOL_TYPE; break;
case TokenType::CHAR_TYPE: type = TokenType::CHAR_TYPE; break;
case TokenType::TEXT_TYPE: type = TokenType::TEXT_TYPE; break;
default: throw SyntaxError(peek().get_location(), "Expected parameter type");
}
next();
// Parse parameter name
if (!check(TokenType::IDENTIFIER))
{
throw SyntaxError(peek().get_location(), "Expected parameter name");
}
// Add parameter to list
parameters.push_back({type, next().get_lexeme()});
} while (match(TokenType::COMMA));
}
if (!match(TokenType::R_PAR)) { throw SyntaxError(peek().get_location(), "Expected ')' after parameters"); }
// Parse function body
BlockNode* body{dynamic_cast<BlockNode*>(parse_block())};
if (!body) { throw SyntaxError(peek().get_location(), "Expected function body"); }
return new FunctionNode(identifier.get_location(), is_mutable, identifier.get_lexeme(), parameters, body);
}
} }
Node* Parser::parse_block() Node* Parser::parse_block()
{ {
LOG_DEBUG("Parse block"); LOG_DEBUG("Parse block");
SourceLocation start{peek().get_location()};
if (!match(TokenType::L_BRACE)) { throw SyntaxError(peek().get_location(), "Expected '{'"); } if (!match(TokenType::L_BRACE)) { throw SyntaxError(peek().get_location(), "Expected '{'"); }
Vector<Node*> statements{}; Vector<Node*> statements{};
...@@ -156,20 +243,23 @@ Node* Parser::parse_block() ...@@ -156,20 +243,23 @@ Node* Parser::parse_block()
if (!match(TokenType::R_BRACE)) { throw SyntaxError(peek().get_location(), "Expected '}'"); } if (!match(TokenType::R_BRACE)) { throw SyntaxError(peek().get_location(), "Expected '}'"); }
return new BlockNode(statements.at(0)->get_location(), statements); return new BlockNode(start, statements);
} }
Node* Parser::parse_control() Node* Parser::parse_control()
{ {
LOG_DEBUG("Parse control flow"); LOG_DEBUG("Parse control flow");
if (match(TokenType::IF)) { return parse_if(); } if (match(TokenType::IF)) { return parse_if(); }
if (match(TokenType::WHILE)) { return parse_while(); } if (match(TokenType::WHILE)) { return parse_while(); }
if (match(TokenType::RETURN)) { return parse_return(); }
return nullptr; return nullptr;
} }
Node* Parser::parse_if() Node* Parser::parse_if()
{ {
LOG_DEBUG("Parse if"); LOG_DEBUG("Parse if");
if (!match(TokenType::L_PAR)) { throw SyntaxError(peek().get_location(), "Expected '('"); } if (!match(TokenType::L_PAR)) { throw SyntaxError(peek().get_location(), "Expected '('"); }
ExpressionNode* condition{dynamic_cast<ExpressionNode*>(parse_expression())}; ExpressionNode* condition{dynamic_cast<ExpressionNode*>(parse_expression())};
...@@ -188,6 +278,7 @@ Node* Parser::parse_if() ...@@ -188,6 +278,7 @@ Node* Parser::parse_if()
Node* Parser::parse_while() Node* Parser::parse_while()
{ {
LOG_DEBUG("Parse while loop"); LOG_DEBUG("Parse while loop");
if (!match(TokenType::L_PAR)) { throw SyntaxError(peek().get_location(), "Expected '('"); } if (!match(TokenType::L_PAR)) { throw SyntaxError(peek().get_location(), "Expected '('"); }
ExpressionNode* condition{dynamic_cast<ExpressionNode*>(parse_expression())}; ExpressionNode* condition{dynamic_cast<ExpressionNode*>(parse_expression())};
if (!match(TokenType::R_PAR)) { throw SyntaxError(peek().get_location(), "Expected ')'"); } if (!match(TokenType::R_PAR)) { throw SyntaxError(peek().get_location(), "Expected ')'"); }
...@@ -195,21 +286,72 @@ Node* Parser::parse_while() ...@@ -195,21 +286,72 @@ Node* Parser::parse_while()
return new WhileNode(condition, body); return new WhileNode(condition, body);
} }
Node* Parser::parse_return()
{
LOG_DEBUG("Parse return");
if (match(TokenType::SEMICOLON)) { return new ReturnNode(peek_prev().get_location(), nullptr); }
ExpressionNode* value{dynamic_cast<ExpressionNode*>(parse_expression())};
if (!value) { throw SyntaxError(peek().get_location(), "Expected expression"); }
if (!match(TokenType::SEMICOLON)) { throw SyntaxError(peek().get_location(), "Expected ';'"); }
return new ReturnNode(peek_prev().get_location(), value);
}
Node* Parser::parse_expression() Node* Parser::parse_expression()
{ {
LOG_DEBUG("Parse expression"); LOG_DEBUG("Parse expression");
return parse_assignment(); return parse_assignment();
} }
Node* Parser::parse_assignment() Node* Parser::parse_assignment()
{ {
LOG_DEBUG("Parse assignment"); LOG_DEBUG("Parse assignment");
return parse_logical_or();
return parse_pipe();
}
Node* Parser::parse_pipe()
{
LOG_DEBUG("Parse pipe");
Node* expr{parse_logical_or()};
while (check(TokenType::PIPE))
{
SourceLocation loc{next().get_location()};
ExpressionNode* source{dynamic_cast<ExpressionNode*>(expr)};
if (!source) { throw SyntaxError(peek().get_location(), "Left side of pipe must be an expression"); }
if (!check(TokenType::IDENTIFIER))
{
throw SyntaxError(peek().get_location(), "Expected function or function call after pipe operator");
}
Token identifier{next()};
Vector<ExpressionNode*> args{};
if (match(TokenType::L_PAR))
{
if (!check(TokenType::R_PAR))
{
do {
args.push_back(dynamic_cast<ExpressionNode*>(parse_expression()));
} while (match(TokenType::COMMA));
}
if (!match(TokenType::R_PAR)) { throw SyntaxError(peek().get_location(), "Expected ')' after arguments"); }
}
expr = new PipeNode(loc, source, new CallNode(identifier, args));
}
return expr;
} }
Node* Parser::parse_logical_or() Node* Parser::parse_logical_or()
{ {
LOG_DEBUG("Parse logical or"); LOG_DEBUG("Parse logical or");
ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_logical_and())}; ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_logical_and())};
while (match(TokenType::OR)) while (match(TokenType::OR))
...@@ -225,6 +367,7 @@ Node* Parser::parse_logical_or() ...@@ -225,6 +367,7 @@ Node* Parser::parse_logical_or()
Node* Parser::parse_logical_and() Node* Parser::parse_logical_and()
{ {
LOG_DEBUG("Parse logical and"); LOG_DEBUG("Parse logical and");
ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_equality())}; ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_equality())};
while (match(TokenType::AND)) while (match(TokenType::AND))
...@@ -240,6 +383,7 @@ Node* Parser::parse_logical_and() ...@@ -240,6 +383,7 @@ Node* Parser::parse_logical_and()
Node* Parser::parse_equality() Node* Parser::parse_equality()
{ {
LOG_DEBUG("Parse equality"); LOG_DEBUG("Parse equality");
ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_comparison())}; ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_comparison())};
while (match(TokenType::EQUAL) || match(TokenType::NOT_EQUAL)) while (match(TokenType::EQUAL) || match(TokenType::NOT_EQUAL))
...@@ -255,6 +399,7 @@ Node* Parser::parse_equality() ...@@ -255,6 +399,7 @@ Node* Parser::parse_equality()
Node* Parser::parse_comparison() Node* Parser::parse_comparison()
{ {
LOG_DEBUG("Parse comparison"); LOG_DEBUG("Parse comparison");
ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_additive())}; ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_additive())};
while (match(TokenType::LESS) || match(TokenType::LESS_EQUAL) || match(TokenType::GREATER) || while (match(TokenType::LESS) || match(TokenType::LESS_EQUAL) || match(TokenType::GREATER) ||
...@@ -271,6 +416,7 @@ Node* Parser::parse_comparison() ...@@ -271,6 +416,7 @@ Node* Parser::parse_comparison()
Node* Parser::parse_additive() Node* Parser::parse_additive()
{ {
LOG_DEBUG("Parse addative"); LOG_DEBUG("Parse addative");
ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_multiplicative())}; ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_multiplicative())};
while (match(TokenType::PLUS) || match(TokenType::MINUS)) while (match(TokenType::PLUS) || match(TokenType::MINUS))
...@@ -286,6 +432,7 @@ Node* Parser::parse_additive() ...@@ -286,6 +432,7 @@ Node* Parser::parse_additive()
Node* Parser::parse_multiplicative() Node* Parser::parse_multiplicative()
{ {
LOG_DEBUG("Parse multiplicative"); LOG_DEBUG("Parse multiplicative");
ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_unary())}; ExpressionNode* left{dynamic_cast<ExpressionNode*>(parse_unary())};
while ( while (
...@@ -302,6 +449,7 @@ Node* Parser::parse_multiplicative() ...@@ -302,6 +449,7 @@ Node* Parser::parse_multiplicative()
Node* Parser::parse_unary() Node* Parser::parse_unary()
{ {
LOG_DEBUG("Parse unary"); LOG_DEBUG("Parse unary");
if (match(TokenType::MINUS) || match(TokenType::NOT)) if (match(TokenType::MINUS) || match(TokenType::NOT))
{ {
Token op{peek_prev()}; Token op{peek_prev()};
...@@ -315,6 +463,7 @@ Node* Parser::parse_unary() ...@@ -315,6 +463,7 @@ Node* Parser::parse_unary()
Node* Parser::parse_factor() Node* Parser::parse_factor()
{ {
LOG_DEBUG("Parse factor"); LOG_DEBUG("Parse factor");
Node* expr = nullptr; Node* expr = nullptr;
if (check(TokenType::IDENTIFIER)) { expr = parse_identifier(); } if (check(TokenType::IDENTIFIER)) { expr = parse_identifier(); }
...@@ -339,6 +488,7 @@ Node* Parser::parse_factor() ...@@ -339,6 +488,7 @@ Node* Parser::parse_factor()
Node* Parser::parse_literal() Node* Parser::parse_literal()
{ {
LOG_DEBUG("Parse literal"); LOG_DEBUG("Parse literal");
Token literal{next()}; Token literal{next()};
return new LiteralNode(literal.get_location(), NodeValue(literal.get_value())); return new LiteralNode(literal.get_location(), NodeValue(literal.get_value()));
} }
...@@ -346,6 +496,7 @@ Node* Parser::parse_literal() ...@@ -346,6 +496,7 @@ Node* Parser::parse_literal()
Node* Parser::parse_identifier() Node* Parser::parse_identifier()
{ {
LOG_DEBUG("Parse identifier"); LOG_DEBUG("Parse identifier");
Token identifier{next()}; Token identifier{next()};
if (check(TokenType::L_PAR)) { return parse_call(identifier); } if (check(TokenType::L_PAR)) { return parse_call(identifier); }
...@@ -378,6 +529,7 @@ Node* Parser::parse_call(const Token& identifier) ...@@ -378,6 +529,7 @@ Node* Parser::parse_call(const Token& identifier)
Node* Parser::parse_method_call(ExpressionNode* object) Node* Parser::parse_method_call(ExpressionNode* object)
{ {
LOG_DEBUG("Parse method call"); LOG_DEBUG("Parse method call");
if (!check(TokenType::IDENTIFIER)) { throw SyntaxError(peek().get_location(), "Expected method name after '.'"); } if (!check(TokenType::IDENTIFIER)) { throw SyntaxError(peek().get_location(), "Expected method name after '.'"); }
Token method{next()}; Token method{next()};
...@@ -400,6 +552,7 @@ Node* Parser::parse_method_call(ExpressionNode* object) ...@@ -400,6 +552,7 @@ Node* Parser::parse_method_call(ExpressionNode* object)
Node* Parser::parse_list() Node* Parser::parse_list()
{ {
LOG_DEBUG("Parse list"); LOG_DEBUG("Parse list");
if (!match(TokenType::L_BRACKET)) { throw SyntaxError(peek().get_location(), "Expected '['"); } if (!match(TokenType::L_BRACKET)) { throw SyntaxError(peek().get_location(), "Expected '['"); }
Vector<ExpressionNode*> elements{}; Vector<ExpressionNode*> elements{};
......
#include "parser/Registry.h"
namespace funk
{
Registry& Registry::instance()
{
static Registry registry;
return registry;
}
bool Registry::add_function(FunctionNode* function)
{
// TODO: Check for duplicate functions for the same identifier and arguments
functions[function->get_identifier()].push_back(function);
return true;
}
FunctionNode* Registry::get_function(const String& identifier, const Vector<ExpressionNode*>& arguments) const
{
// Check if the function exists
if (functions.find(identifier) == functions.end()) { return nullptr; }
// Check if the function is a pattern matching function
for (FunctionNode* function : functions.at(identifier))
{
if (function->is_pattern_matching() && function->matches(arguments)) { return function; }
}
// Check if the function is a regular function
for (FunctionNode* function : functions.at(identifier))
{
if (!function->is_pattern_matching() && function->get_parameters().size() == arguments.size())
{
return function;
}
}
return nullptr;
}
void Registry::remove_function(const String& identifier)
{
functions.erase(identifier);
}
bool Registry::contains(const String& identifier) const
{
return functions.find(identifier) != functions.end();
}
} // namespace funk
...@@ -47,4 +47,13 @@ Node* Scope::get(const String& name) const ...@@ -47,4 +47,13 @@ Node* Scope::get(const String& name) const
return nullptr; return nullptr;
} }
bool Scope::contains(const String& name) const
{
for (int i = scopes.size() - 1; i >= 0; i--)
{
if (scopes[i].find(name) != scopes[i].end()) { return true; }
}
return false;
}
} // namespace funk } // namespace funk
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment