Skip to content
Snippets Groups Projects
Commit 55f37cb6 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Add support for decreasing time resolution

parent d936a145
No related branches found
No related tags found
1 merge request!94Add support for decreasing time resolution
Pipeline #87912 passed
......@@ -231,10 +231,10 @@ class Operation(GraphComponent, SignalSourceProvider):
which ignores the word length specified by the input signal.
The *truncate* parameter specifies whether input truncation should be enabled in the first
place. If set to False, input values will be used directly without any bit truncation.
See also
========
evaluate_outputs, current_output, current_outputs
"""
raise NotImplementedError
......@@ -931,7 +931,13 @@ class AbstractOperation(Operation, AbstractGraphComponent):
self._execution_time *= factor
for port in [*self.inputs, *self.outputs]:
port.latency_offset *= factor
def _decrease_time_resolution(self, factor: int):
if self._execution_time is not None:
self._execution_time = self._execution_time // factor
for port in [*self.inputs, *self.outputs]:
port.latency_offset = port.latency_offset // factor
def get_plot_coordinates(
self,
) -> Tuple[List[List[Number]], List[List[Number]]]:
......
......@@ -5,6 +5,7 @@ Contains the schedule class for scheduling operations in an SFG.
"""
import io
import math
import sys
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
......@@ -191,12 +192,61 @@ class Schedule:
self._start_times = {
k: factor * v for k, v in self._start_times.items()
}
for op_id, op_start_time in self._start_times.items():
for op_id in self._start_times:
self._sfg.find_by_id(op_id)._increase_time_resolution(factor)
self._schedule_time *= factor
return self
def _get_all_times(self) -> List[int]:
"""
Return a list of all times for the schedule. Used to check how the
resolution can be modified.
"""
# Local values
ret = [self._schedule_time, *self._start_times.values()]
# Loop over operations
for op_id in self._start_times:
op = self._sfg.find_by_id(op_id)
ret += [op.execution_time, *op.latency_offsets.values()]
# Remove not set values (None)
ret = [v for v in ret if v is not None]
return ret
def get_possible_time_resolution_decrements(self) -> List[int]:
"""Return a list with possible factors to reduce time resolution."""
vals = self._get_all_times()
maxloop = min(val for val in vals if val)
if maxloop <= 1:
return [1]
ret = [1]
for candidate in range(2, maxloop + 1):
if not any(val % candidate for val in vals):
ret.append(candidate)
return ret
def decrease_time_resolution(self, factor: int) -> "Schedule":
raise NotImplementedError
"""
Decrease time resolution for a schedule.
Parameters
==========
factor : int
The time resolution decrement.
"""
possible_values = self.get_possible_time_resolution_decrements()
if factor not in possible_values:
raise ValueError(
f"Not possible to decrease resolution with {factor}. Possible"
f" values are {possible_values}"
)
self._start_times = {
k: v // factor for k, v in self._start_times.items()
}
for op_id, _ in self._start_times.items():
self._sfg.find_by_id(op_id)._decrease_time_resolution(factor)
self._schedule_time = self._schedule_time // factor
return self
def move_operation(self, op_id: GraphID, time: int) -> "Schedule":
assert (
......
......@@ -305,6 +305,7 @@ class TestTimeResolution:
scheduling_alg="ASAP",
)
old_schedule_time = schedule.schedule_time
assert schedule.get_possible_time_resolution_decrements() == [1]
schedule.increase_time_resolution(2)
......@@ -330,6 +331,7 @@ class TestTimeResolution:
}
assert 2 * old_schedule_time == schedule.schedule_time
assert schedule.get_possible_time_resolution_decrements() == [1, 2]
def test_increase_time_resolution_twice(
self, sfg_two_inputs_two_outputs_independent_with_cmul
......@@ -365,3 +367,72 @@ class TestTimeResolution:
}
assert 6 * old_schedule_time == schedule.schedule_time
assert schedule.get_possible_time_resolution_decrements() == [
1,
2,
3,
6,
]
def test_increase_decrease_time_resolution(
self, sfg_two_inputs_two_outputs_independent_with_cmul
):
schedule = Schedule(
sfg_two_inputs_two_outputs_independent_with_cmul,
scheduling_alg="ASAP",
)
old_schedule_time = schedule.schedule_time
assert schedule.get_possible_time_resolution_decrements() == [1]
schedule.increase_time_resolution(6)
start_times_names = {}
for op_id, start_time in schedule._start_times.items():
op_name = (
sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(
op_id
).name
)
start_times_names[op_name] = start_time
assert start_times_names == {
"C1": 0,
"IN1": 0,
"IN2": 0,
"CMUL1": 0,
"CMUL2": 30,
"ADD1": 0,
"CMUL3": 42,
"OUT1": 54,
"OUT2": 60,
}
with pytest.raises(
ValueError, match="Not possible to decrease resolution"
):
schedule.decrease_time_resolution(4)
schedule.decrease_time_resolution(3)
start_times_names = {}
for op_id, start_time in schedule._start_times.items():
op_name = (
sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(
op_id
).name
)
start_times_names[op_name] = start_time
assert start_times_names == {
"C1": 0,
"IN1": 0,
"IN2": 0,
"CMUL1": 0,
"CMUL2": 10,
"ADD1": 0,
"CMUL3": 14,
"OUT1": 18,
"OUT2": 20,
}
assert 2 * old_schedule_time == schedule.schedule_time
assert schedule.get_possible_time_resolution_decrements() == [1, 2]
......@@ -1379,10 +1379,10 @@ class TestPrecedenceGraph:
" shape=square]\n\tcmul1 -> \"cmul1.0\"\n\tcmul1 [label=cmul1"
" shape=square]\n\t\"add1.0\" -> t1In\n\tt1In [label=t1"
" shape=square]\n\tadd1 -> \"add1.0\"\n\tadd1 [label=add1"
" shape=square]\n}\n"
" shape=square]\n}"
)
assert sfg_simple_filter.precedence_graph().source == res
assert sfg_simple_filter.precedence_graph().source in (res, res + "\n")
class TestSFGGraph:
......@@ -1391,20 +1391,20 @@ class TestSFGGraph:
"digraph {\n\trankdir=LR\n\tin1\n\tin1 -> "
"add1\n\tout1\n\tt1 -> out1\n\tadd1\n\tcmul1 -> "
"add1\n\tcmul1\n\tadd1 -> t1\n\tt1 [shape=square]\n\tt1 "
"-> cmul1\n}\n"
"-> cmul1\n}"
)
assert sfg_simple_filter.sfg().source == res
assert sfg_simple_filter.sfg().source in (res, res + "\n")
def test_sfg_show_id(self, sfg_simple_filter):
res = (
"digraph {\n\trankdir=LR\n\tin1\n\tin1 -> add1 "
"[label=s1]\n\tout1\n\tt1 -> out1 [label=s2]\n\tadd1"
"\n\tcmul1 -> add1 [label=s3]\n\tcmul1\n\tadd1 -> t1 "
"[label=s4]\n\tt1 [shape=square]\n\tt1 -> cmul1 [label=s5]\n}\n"
"[label=s4]\n\tt1 [shape=square]\n\tt1 -> cmul1 [label=s5]\n}"
)
assert sfg_simple_filter.sfg(show_id=True).source == res
assert sfg_simple_filter.sfg(show_id=True).source in (res, res + "\n")
def test_show_sfg_invalid_format(self, sfg_simple_filter):
with pytest.raises(ValueError):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment