File size: 2,305 Bytes
1c8d125
 
 
 
96e1a32
 
1c8d125
 
 
 
 
 
 
 
 
 
 
96e1a32
1c8d125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96e1a32
1c8d125
 
 
96e1a32
1c8d125
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import random


def sample_future_length(
    range: tuple[int, int] | str = "gift_eval",
    total_length: int | None = None,
) -> int:
    """
    Sample a forecast length.

    - If `range` is a tuple, uniformly sample in [min, max]. When `total_length` is
      provided, enforce a cap so the result is at most floor(0.45 * total_length).
    - If `range` is "gift_eval", sample from a pre-defined weighted set. When
      `total_length` is provided, filter out candidates greater than
      floor(0.45 * total_length) before sampling.
    """
    # Compute the cap when total_length is provided
    cap: int | None = None
    if total_length is not None:
        cap = max(1, int(0.45 * int(total_length)))

    if isinstance(range, tuple):
        min_len, max_len = range
        if cap is not None:
            effective_max_len = min(max_len, cap)
            # Ensure valid bounds
            if min_len > effective_max_len:
                return effective_max_len
            return random.randint(min_len, effective_max_len)
        return random.randint(min_len, max_len)
    elif range == "gift_eval":
        # Gift eval forecast lengths with their frequencies
        GIFT_EVAL_FORECAST_LENGTHS = {
            48: 5,
            720: 38,
            480: 38,
            30: 3,
            300: 16,
            8: 2,
            120: 3,
            450: 8,
            80: 8,
            12: 2,
            900: 10,
            180: 3,
            600: 10,
            60: 3,
            210: 3,
            195: 3,
            140: 3,
            130: 3,
            14: 1,
            18: 1,
            13: 1,
            6: 1,
        }

        lengths = list(GIFT_EVAL_FORECAST_LENGTHS.keys())
        weights = list(GIFT_EVAL_FORECAST_LENGTHS.values())

        if cap is not None:
            filtered = [
                (length_candidate, weight)
                for length_candidate, weight in zip(lengths, weights, strict=True)
                if length_candidate <= cap
            ]
            if filtered:
                lengths, weights = zip(*filtered, strict=True)
                lengths = list(lengths)
                weights = list(weights)

        return random.choices(lengths, weights=weights)[0]
    else:
        raise ValueError(f"Invalid range: {range}")