diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/Makefile" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/Makefile" new file mode 100644 index 0000000000000000000000000000000000000000..afb1b9459b93dfe8e5bc450f6a3070d5b6fa787c --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/Makefile" @@ -0,0 +1,10 @@ +TARGET := checker +SRC := checker.cc + +all: $(TARGET) + +$(TARGET): $(SRC) + g++ -O2 -o $@ $< + +clean: + rm -f $(TARGET) diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/README.md" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/README.md" new file mode 100644 index 0000000000000000000000000000000000000000..f4ac0a72d17c0a07e69a7dd7da611a7d07ae95e4 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/README.md" @@ -0,0 +1,8 @@ +## Build +Just `make` + +## Run Example + +`./checker example/infile.txt example/outfile.txt example/outfile.txt` + +Where `infile.txt` is the input file, `outfile.txt` is your output file. diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/checker" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/checker" new file mode 100755 index 0000000000000000000000000000000000000000..bf3c6fcaf4bda27008745770ce8c973760d61430 Binary files /dev/null and "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/checker" differ diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/checker.cc" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/checker.cc" new file mode 100644 index 0000000000000000000000000000000000000000..acf43ae3775d78b044c11f4ea507679d99397b32 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/checker.cc" @@ -0,0 +1,194 @@ +#include "testlib.h" + +struct Input { + int L, M, N; + struct Ops { + int addr, size, start, tim; + }; + std::vector ops; +}input; + +struct Output { + struct Ops { + std::string opName; + uint64_t T; + int A, S; + }; + std::vector ops; +}output; + +void check_input() +{ + input.L = inf.readInt(1, 100000, "L"); + input.M = inf.readInt(1, input.L, "M"); + input.N = inf.readInt(1, 10000, "N"); + int last_start = -1; + int N = input.N, L = input.L; + for (int i = 0; i < N; i++) { + int addr = inf.readInt(0, L - 1, "addr_i"); + int size = inf.readInt(1, L, "size_i"); + int start = inf.readInt(0, 1e9, "start_i"); + int tim = inf.readInt(0, 1e9, "time_i"); + quitif(addr + size > L, _fail, "[Invalid Input] addr + size > L, where addr = %d, size = %d, L = %d", addr, size, L); + quitif(start < last_start, _fail, "[Invalid Input] start[%d] > start[%d] (%d > %d)", i - 1, i, last_start, start); + if (start == last_start) { + quitif(input.ops[i - 1].tim != tim, _fail, "[Invalid Input] start[%d] == start[%d] but time[%d] != time[%d]", + i - 1, i, i - 1, i); + } + last_start = start; + input.ops.push_back({addr, size, start, tim}); + } + if (inf.seekEof() == false) { + quitf(_fail, "[Invalid Input] Find task numbers greater than N (which is %d).", N); + } + inf.readEof(); +} + +void check_output() +{ + int last_visit = -1; + uint64_t last_T = 0; + int totalReadLines = 0; + bool finish = false; + std::vector duplicate(input.N, false); + + while (!finish) { + std::string opName = ouf.readWord(); + uint64_t T; + int A, S = 0; + totalReadLines++; + if (totalReadLines > 10 * input.N) { + quitf(_pe, "[Invalid Output] Too many lines. %d > 10n", totalReadLines); + } + if (opName == "Reload") { + T = ouf.readLong(0, (long long)1e18, "T"); + A = ouf.readInt(0, input.L - 1, "A"); + S = ouf.readInt(1, input.L, "S"); + quitif(A + S > input.L, _wa, "[Invalid Output] Reload addr + size > L, where addr = %d, size = %d, L = %d", A, S, input.L); + } else if (opName == "Visit") { + T = ouf.readLong(0, (long long)1e18, "T"); + A = ouf.readInt(0, input.N - 1, "A"); + quitif(duplicate[A], _wa, "[Invalid Output] Output (Visit %d) more than once.", A); + duplicate[A] = true; + if (last_visit != -1 && input.ops[last_visit].start == input.ops[A].start) { + // PASS. We do not care the Visit order in the same task. + } else { + quitif(last_visit >= A, _wa, "[Invalid Output] Tasks are not finished by input sequence. %d >= %d. Op[%d] = (%s %llu %d)", + last_visit, A, totalReadLines - 1, opName.c_str(), T, A); + } + last_visit = A; + } else if (opName == "Offload") { + T = ouf.readLong(0, (long long)1e18, "T"); + A = ouf.readInt(0, input.L - 1, "A"); + S = ouf.readInt(1, input.L, "S"); + quitif(A + S > input.L, _wa, "[Invalid Output] Offload addr + size > L, where addr = %d, size = %d, L = %d.", A, S, input.L); + } else if (opName == "Fin") { + T = ouf.readLong(0, (long long)1e18, "T"); + finish = true; + } else { + quitf(_wa, "[Invalid Output] Unknown operator %s", opName.c_str()); + } + quitif(last_T > T, _wa, "[Invalid Output] Output T is not ascending. %llu > %llu. Op[%d] = (%s %llu ...)", + last_T, T, totalReadLines - 1, opName.c_str(), T); + last_T = T; + output.ops.push_back({opName, T, A, S}); + } + int finish_tasks = 0; + for (int i = 0; i < input.N; i++) { + if (duplicate[i]) { + finish_tasks++; + } + } + quitif(finish_tasks != input.N, _wa, "[Invalid Output] Output did not finish all Visit tasks. %d != N (which is %d)", finish_tasks, input.N); + if (ouf.seekEof() == false) { + quitf(_pe, "[Invalid Output] Find extra lines after output Fin."); + } + ouf.readEof(); +} + +uint64_t get_score() +{ + const int MEM_OFFLOAD = -2; + const int MEM_RELOAD = -1; + const uint64_t multiple_IO = 40; + std::vector in_mem(input.L, MEM_OFFLOAD); + std::vector task_finish_at(input.N, 0); + int use_mem = 0; + int nr_output = output.ops.size(); + int last_visit = -1; + uint64_t score = 0; + uint64_t io_time = 0, npu_time = 0; + for (int i = 0; i < nr_output; i++) { + auto curr = output.ops[i]; + if (curr.opName == "Reload") { + quitif(io_time > curr.T, _wa, "[Invalid Output] IO is busy. Last IO task finish at %llu. Op[%d] = (%s %llu %d %d)", + io_time, i, curr.opName.c_str(), curr.T, curr.A, curr.S); + int cnt_tomem = 0; + for (int j = curr.A; j < curr.A + curr.S; j++) { + if (in_mem[j] == MEM_OFFLOAD) { + cnt_tomem++; + in_mem[j] = MEM_RELOAD; + } + } + use_mem += cnt_tomem; + quitif(use_mem > input.M, _wa, "[Invalid Output] Out of Memory. use_mem = %d, M = %d. Op[%d] = (%s %llu %d %d)", + use_mem, input.M, i, curr.opName.c_str(), curr.T, curr.A, curr.S); + io_time = curr.T + multiple_IO * cnt_tomem; + } else if (curr.opName == "Visit") { + uint64_t this_task_finish_time = curr.T + input.ops[curr.A].tim; + if (last_visit != -1 && input.ops[last_visit].start == input.ops[curr.A].start) { + // The same task. + quitif(this_task_finish_time != npu_time, _wa, "[Invalid Output] Visit %d and Visit %d must start at the same time. Op[%d] = (%s %llu %d)", + last_visit, curr.A, i, curr.opName.c_str(), curr.T, curr.A); + } else { + quitif(npu_time > curr.T, _wa, "[Invalid Output] NPU is busy. Last NPU task finish at %llu. Op[%d] = (%s %llu %d)", + npu_time, i, curr.opName.c_str(), curr.T, curr.A); + npu_time = this_task_finish_time; + } + task_finish_at[curr.A] = this_task_finish_time; + last_visit = curr.A; + quitif(curr.T < input.ops[curr.A].start, _wa, "[Invalid Output] Task %d is not ready. Start time = %llu. Op[%d] = (%s %llu %d)", + curr.A, input.ops[curr.A].start, i, curr.opName.c_str(), curr.T, curr.A); + for (int j = input.ops[curr.A].addr; j < input.ops[curr.A].addr + input.ops[curr.A].size; j++) { + quitif(in_mem[j] == MEM_OFFLOAD, _wa, "[Invalid Output] Addr %d is not in memory. Op[%d] = (%s %llu %d)", + j, i, curr.opName.c_str(), curr.T, curr.A); + if (in_mem[j] == MEM_RELOAD || task_finish_at[in_mem[j]] < task_finish_at[curr.A]) { + in_mem[j] = curr.A; + } + } + } else if (curr.opName == "Offload") { + quitif(io_time > curr.T, _wa, "[Invalid Output] IO is busy. Last IO task finish at %llu. Op[%d] = (%s %llu %d %d)", + io_time, i, curr.opName.c_str(), curr.T, curr.A, curr.S); + int cnt_offmem = 0; + for (int j = curr.A; j < curr.A + curr.S; j++) { + if (in_mem[j] == MEM_OFFLOAD) { + continue; + } + if (in_mem[j] != MEM_RELOAD) { + quitif(curr.T < task_finish_at[in_mem[j]], _wa, "[Invalid Output] Addr %d is used by NPU task %d at T=%llu. Op[%d] = (%s %llu %d %d)", + j, in_mem[j], curr.T, i, curr.opName.c_str(), curr.T, curr.A, curr.S); + } + cnt_offmem++; + in_mem[j] = MEM_OFFLOAD; + } + use_mem -= cnt_offmem; + io_time = curr.T + multiple_IO * cnt_offmem; + } else if (curr.opName == "Fin") { + score = std::max(io_time, npu_time); + quitif(score > curr.T, _wa, "[Invalid Output] Output Fin, but not all the resources are freed. Last IO task finish at %llu, Last NPU task finish at %llu", + io_time, npu_time); + score = curr.T; + } + } + return score; +} + +int main(int argc, char *argv[]) +{ + setName("Global_Memory_Planning_for_LLM"); + registerTestlibCmd(argc, argv); + check_input(); + check_output(); + auto score = get_score(); + quitf(_ok, "All tasks finish at %llu", score); +} \ No newline at end of file diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_01.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_01.txt" new file mode 100644 index 0000000000000000000000000000000000000000..b6381f06dde1d64acefdf54a15477536fc034af0 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_01.txt" @@ -0,0 +1,5 @@ +600 300 4 +0 100 0 20000 +100 100 0 20000 +200 200 1 50 +400 100 2 50 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_02.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_02.txt" new file mode 100644 index 0000000000000000000000000000000000000000..92aee8b1325daabe9e51cf57aa068317baf4b521 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_02.txt" @@ -0,0 +1,4 @@ +300 200 3 +0 100 0 50 +100 100 4000 30 +150 100 4001 20 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_03.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_03.txt" new file mode 100644 index 0000000000000000000000000000000000000000..3a7df150362195a17b17e85ddb01c4e78959c3c2 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_03.txt" @@ -0,0 +1,4 @@ +300 200 3 +0 100 0 5000 +100 100 0 5000 +50 100 4001 20 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_04.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_04.txt" new file mode 100644 index 0000000000000000000000000000000000000000..eb2bd89caf1e9a3d93552b74f64e0c09310030ae --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_04.txt" @@ -0,0 +1,4 @@ +300 200 3 +0 100 500 90000 +100 100 500 90000 +200 10 501 10 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_05.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_05.txt" new file mode 100644 index 0000000000000000000000000000000000000000..b0e48638d9b54c338e9bae5b25825ff5c9363803 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_05.txt" @@ -0,0 +1,3 @@ +200 100 2 +0 100 0 30 +100 100 50 10 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_06.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_06.txt" new file mode 100644 index 0000000000000000000000000000000000000000..2ae99657e7b8ebe00e82f45eeabf7793cebb935a --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_06.txt" @@ -0,0 +1,4 @@ +300 150 3 +0 100 0 50 +100 50 1000 10 +200 50 1000 10 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_07.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_07.txt" new file mode 100644 index 0000000000000000000000000000000000000000..3a7df150362195a17b17e85ddb01c4e78959c3c2 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_07.txt" @@ -0,0 +1,4 @@ +300 200 3 +0 100 0 5000 +100 100 0 5000 +50 100 4001 20 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_08.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_08.txt" new file mode 100644 index 0000000000000000000000000000000000000000..c0bb5531aab179ebdff2c286b3e566f3de6a0faf --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_08.txt" @@ -0,0 +1,4 @@ +200 100 3 +0 100 0 10 +0 100 2000 10 +0 100 4000 10 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_09.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_09.txt" new file mode 100644 index 0000000000000000000000000000000000000000..ab691df6c47397c7d0662b112bd9a6d34f6bad18 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_09.txt" @@ -0,0 +1,4 @@ +400 200 3 +0 150 0 50 +100 150 100 50 +250 150 200 50 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_10.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_10.txt" new file mode 100644 index 0000000000000000000000000000000000000000..a86ac39fe068008e61e2038715d97e4003d3fe36 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_10.txt" @@ -0,0 +1,4 @@ +300 100 3 +0 100 0 100 +100 100 50 50 +200 100 100 50 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_11.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_11.txt" new file mode 100644 index 0000000000000000000000000000000000000000..8c51e820311ef5b3dd4e282fcd364fb972df0319 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_11.txt" @@ -0,0 +1,4 @@ +400 200 3 +0 100 0 10000 +100 100 100 50 +200 100 200 50 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_12.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_12.txt" new file mode 100644 index 0000000000000000000000000000000000000000..a73f829f1aa8ea8b6d3205096ae6ef43ccf044df --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_12.txt" @@ -0,0 +1,5 @@ +500 200 4 +0 150 0 100 +100 150 100 100 +50 150 200 100 +200 150 300 100 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_13.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_13.txt" new file mode 100644 index 0000000000000000000000000000000000000000..11cd175f3ea88c3026e8eae99bc15e28ce0d58d9 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_13.txt" @@ -0,0 +1,4 @@ +300 300 3 +0 100 0 30 +100 100 50 20 +200 100 100 10 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_14.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_14.txt" new file mode 100644 index 0000000000000000000000000000000000000000..ba0c403e27c48d6b45a197987d5113991d63e6ad --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_14.txt" @@ -0,0 +1,5 @@ +400 150 4 +0 100 0 200 +100 50 100 50 +200 100 300 50 +250 50 500 50 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_15.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_15.txt" new file mode 100644 index 0000000000000000000000000000000000000000..0edc9253e7ccdbbebc95c337607d19c7cdc7da14 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_15.txt" @@ -0,0 +1,6 @@ +400 150 5 +0 100 0 20 +50 100 5 20 +150 50 30 10 +200 100 50 10 +300 50 100 10 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_16.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_16.txt" new file mode 100644 index 0000000000000000000000000000000000000000..5efa9aaa2142a182e6b2214d1879b1a90f5d2695 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_16.txt" @@ -0,0 +1,3 @@ +400 150 2 +0 150 0 10000 +150 150 2000 10 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_17.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_17.txt" new file mode 100644 index 0000000000000000000000000000000000000000..299fe98dea00b424b53cbf1f3da0f92998b8dd9c --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_17.txt" @@ -0,0 +1,5 @@ +300 150 4 +0 100 0 50 +100 50 0 50 +200 100 100 50 +250 50 100 50 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_18.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_18.txt" new file mode 100644 index 0000000000000000000000000000000000000000..8212ab7563c3065deb711f65f0662bf456456bf5 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_18.txt" @@ -0,0 +1,4 @@ +400 200 3 +0 100 0 5 +100 100 10 5 +200 100 20 5 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_19.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_19.txt" new file mode 100644 index 0000000000000000000000000000000000000000..d53567109f09d69aef35f9a17d925a7acccef4a4 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_19.txt" @@ -0,0 +1,4 @@ +600 200 3 +0 100 0 50 +300 100 100 50 +500 100 200 50 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_20.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_20.txt" new file mode 100644 index 0000000000000000000000000000000000000000..4b2c5e607c24948487dfda1b19bdab5e6edc6288 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_20.txt" @@ -0,0 +1,5 @@ +500 150 4 +0 100 0 50 +50 100 50 50 +100 100 100 50 +150 100 150 50 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_21.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_21.txt" new file mode 100644 index 0000000000000000000000000000000000000000..c8920876aa6d59fa6d6dd707c43d6c9e833a64c8 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_21.txt" @@ -0,0 +1,4 @@ +400 200 3 +0 100 0 10 +100 100 0 10 +200 100 100 10 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_22.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_22.txt" new file mode 100644 index 0000000000000000000000000000000000000000..7d86ee29971d2289698a46ea3ed64de4956d8082 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_22.txt" @@ -0,0 +1,4 @@ +400 150 3 +0 100 0 200 +100 50 100 10 +200 100 300 10 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_23.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_23.txt" new file mode 100644 index 0000000000000000000000000000000000000000..47b69ecfa97a167b4ae46478ece3076e06930e1d --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_23.txt" @@ -0,0 +1,5 @@ +300 100 4 +0 100 0 30 +50 100 30 30 +100 100 60 30 +150 100 90 30 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_24.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_24.txt" new file mode 100644 index 0000000000000000000000000000000000000000..5f6d55fc284e1002498b2cdb8e3ccc086f8e5dcd --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_24.txt" @@ -0,0 +1,5 @@ +400 150 4 +0 100 0 50 +100 50 10 20 +0 50 100 20 +150 100 200 20 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_25.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_25.txt" new file mode 100644 index 0000000000000000000000000000000000000000..16765657a2b89339fc393fdd1ecee62d82476933 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_25.txt" @@ -0,0 +1,4 @@ +400 200 3 +100 200 0 50 +0 100 100 50 +100 100 200 50 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_26.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_26.txt" new file mode 100644 index 0000000000000000000000000000000000000000..e7a500aa83d6d2f6c65f79f96e5692e5f6d992b6 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_26.txt" @@ -0,0 +1,4 @@ +300 200 3 +0 200 0 4000 +100 100 0 4000 +200 100 5001 50 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_27.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_27.txt" new file mode 100644 index 0000000000000000000000000000000000000000..8a40995cdadd466784abfd5b616250710fe08747 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_27.txt" @@ -0,0 +1,4 @@ +300 200 3 +0 100 0 1000 +100 10 0 1000 +0 10 10001 50 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_28.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_28.txt" new file mode 100644 index 0000000000000000000000000000000000000000..041a75613cfae3f904bff9d3faea4c62fac478e0 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_28.txt" @@ -0,0 +1,4 @@ +300 200 3 +0 100 0 100 +100 100 0 100 +0 101 0 100 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_29.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_29.txt" new file mode 100644 index 0000000000000000000000000000000000000000..eec290a030fa3e9018f13bd2c9c921b801c82f3f --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_29.txt" @@ -0,0 +1,4 @@ +200 200 3 +0 100 0 10 +100 100 0 10 +0 10 1 10000 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_30.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_30.txt" new file mode 100644 index 0000000000000000000000000000000000000000..cff4a8988686c06524cb9f4566375e95c7ab981f --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_30.txt" @@ -0,0 +1,5 @@ +1000 500 4 +0 50 0 100000 +100 100 1 10 +200 100 2 10 +300 100 3 10 diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_31.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_31.txt" new file mode 100644 index 0000000000000000000000000000000000000000..2fd319972b9610d16ca82430dbe17b3c09ed8397 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_31.txt" @@ -0,0 +1,4 @@ +300 300 3 +0 10 0 10000 +10 10 10001 10 +20 200 10002 10 \ No newline at end of file diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_32.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_32.txt" new file mode 100644 index 0000000000000000000000000000000000000000..40caed4791db632e13c77859b8a40f0d2b65e032 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_32.txt" @@ -0,0 +1,6 @@ +2000 1001 5 +0 1000 0 10 +1000 1 0 10 +1001 10 100 10 +1000 1 200 10 +0 1000 300 10 \ No newline at end of file diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_33.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_33.txt" new file mode 100644 index 0000000000000000000000000000000000000000..3440eaf321080576b719debaa3d02758fcfed5d9 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_33.txt" @@ -0,0 +1,4 @@ +10000 10000 3 +0 100 0 10000 +100 100 10000 100 +200 5000 10100 10 \ No newline at end of file diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_34.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_34.txt" new file mode 100644 index 0000000000000000000000000000000000000000..b8e379623b2214b5b1c0fbf8bced9cd16b03ff55 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_34.txt" @@ -0,0 +1,13 @@ +100000 100000 12 +0 100 0 4000 +100 10 4000 4000 +110 10 8000 4000 +120 10 12000 4000 +130 10 16000 4000 +140 10 20000 4000 +150 10 24000 4000 +160 10 28000 4000 +170 10 32000 4000 +180 10 36000 4000 +5000 800 40000 10 +0 100 100000 10 \ No newline at end of file diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_35.txt" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_35.txt" new file mode 100644 index 0000000000000000000000000000000000000000..69fca2d7da9e64476b4bef02767a7140f0a828c1 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/example/infile/infile_35.txt" @@ -0,0 +1,5 @@ +1000 100 4 +0 90 0 10000 +100 10 1 4999 +200 10 4000 10 +100 10 20000 10 \ No newline at end of file diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/testlib.h" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/testlib.h" new file mode 100644 index 0000000000000000000000000000000000000000..4e1d3682e1a0e7c2000ed9c2b24587062646fe91 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/checker/testlib.h" @@ -0,0 +1,6299 @@ +/* + * It is strictly recommended to include "testlib.h" before any other include + * in your code. In this case testlib overrides compiler specific "random()". + * + * If you can't compile your code and compiler outputs something about + * ambiguous call of "random_shuffle", "rand" or "srand" it means that + * you shouldn't use them. Use "shuffle", and "rnd.next()" instead of them + * because these calls produce stable result for any C++ compiler. Read + * sample generator sources for clarification. + * + * Please read the documentation for class "random_t" and use "rnd" instance in + * generators. Probably, these sample calls will be useful for you: + * rnd.next(); rnd.next(100); rnd.next(1, 2); + * rnd.next(3.14); rnd.next("[a-z]{1,100}"). + * + * Also read about wnext() to generate off-center random distribution. + * + * See https://github.com/MikeMirzayanov/testlib/ to get latest version or bug tracker. + */ + +#ifndef _TESTLIB_H_ +#define _TESTLIB_H_ + +/* + * Copyright (c) 2005-2024 + */ + +#define VERSION "0.9.44" + +/* + * Mike Mirzayanov + * + * This material is provided "as is", with absolutely no warranty expressed + * or implied. Any use is at your own risk. + * + * Permission to use or copy this software for any purpose is hereby granted + * without fee, provided the above notices are retained on all copies. + * Permission to modify the code and to distribute modified code is granted, + * provided the above notices are retained, and a notice that the code was + * modified is included with the above copyright notice. + * + */ + +/* NOTE: This file contains testlib library for C++. + * + * Check, using testlib running format: + * check.exe [ [-appes]], + * If result file is specified it will contain results. + * + * Validator, using testlib running format: + * validator.exe < input.txt, + * It will return non-zero exit code and writes message to standard output. + * + * Generator, using testlib running format: + * gen.exe [parameter-1] [parameter-2] [... paramerter-n] + * You can write generated test(s) into standard output or into the file(s). + * + * Interactor, using testlib running format: + * interactor.exe [ [ [-appes]]], + * Reads test from inf (mapped to args[1]), writes result to tout (mapped to argv[2], + * can be judged by checker later), reads program output from ouf (mapped to stdin), + * writes output to program via stdout (use cout, printf, etc). + */ + +const char *latestFeatures[] = { + "Added ConstantBoundsLog, VariablesLog to validator testOverviewLogFile", + "Use setAppesModeEncoding to change xml encoding from windows-1251 to other", + "rnd.any/wany use distance/advance instead of -/+: now they support sets/multisets", + "Use syntax `int t = inf.readInt(1, 3, \"~t\");` to skip the lower bound check. Tildes can be used on either side or both: ~t, t~, ~t~", + "Supported EJUDGE support in registerTestlibCmd", + "Supported '--testMarkupFileName fn' and '--testCase tc/--testCaseFileName fn' for validators", + "Added opt defaults via opt(key/index, default_val); check unused opts when using has_opt or default opt (turn off this check with suppressEnsureNoUnusedOpt()).", + "For checker added --group and --testset command line params (like for validator), use checker.group() or checker.testset() to get values", + "Added quitpi(points_info, message) function to return with _points exit code 7 and given points_info", + "rnd.partition(size, sum[, min_part=1]) returns random (unsorted) partition which is a representation of the given `sum` as a sum of `size` positive integers (or >=min_part if specified)", + "rnd.distinct(size, n) and rnd.distinct(size, from, to)", + "opt(\"some_missing_key\") returns false now", + "has_opt(key)", + "Abort validator on validator.testset()/validator.group() if registered without using command line", + "Print integer range violations in a human readable way like `violates the range [1, 10^9]`", + "Opts supported: use them like n = opt(\"n\"), in a command line you can use an exponential notation", + "Reformatted", + "Use setTestCase(i) or unsetTestCase() to support test cases (you can use it in any type of program: generator, interactor, validator or checker)", + "Fixed issue #87: readStrictDouble accepts \"-0.00\"", + "Fixed issue #83: added InStream::quitif(condition, ...)", + "Fixed issue #79: fixed missed guard against repeated header include", + "Fixed issue #80: fixed UB in case of huge quitf message", + "Fixed issue #84: added readXs(size, indexBase = 1)", + "Fixed stringstream repeated usage issue", + "Fixed compilation in g++ (for std=c++03)", + "Batch of println functions (support collections, iterator ranges)", + "Introduced rnd.perm(size, first = 0) to generate a `first`-indexed permutation", + "Allow any whitespace in readInts-like functions for non-validators", + "Ignore 4+ command line arguments ifdef EJUDGE", + "Speed up of vtos", + "Show line number in validators in case of incorrect format", + "Truncate huge checker/validator/interactor message", + "Fixed issue with readTokenTo of very long tokens, now aborts with _pe/_fail depending of a stream type", + "Introduced InStream::ensure/ensuref checking a condition, returns wa/fail depending of a stream type", + "Fixed compilation in VS 2015+", + "Introduced space-separated read functions: readWords/readTokens, multilines read functions: readStrings/readLines", + "Introduced space-separated read functions: readInts/readIntegers/readLongs/readUnsignedLongs/readDoubles/readReals/readStrictDoubles/readStrictReals", + "Introduced split/tokenize functions to separate string by given char", + "Introduced InStream::readUnsignedLong and InStream::readLong with unsigned long long parameters", + "Supported --testOverviewLogFileName for validator: bounds hits + features", + "Fixed UB (sequence points) in random_t", + "POINTS_EXIT_CODE returned back to 7 (instead of 0)", + "Removed disable buffers for interactive problems, because it works unexpectedly in wine", + "InStream over string: constructor of InStream from base InStream to inherit policies and std::string", + "Added expectedButFound quit function, examples: expectedButFound(_wa, 10, 20), expectedButFound(_fail, ja, pa, \"[n=%d,m=%d]\", n, m)", + "Fixed incorrect interval parsing in patterns", + "Use registerGen(argc, argv, 1) to develop new generator, use registerGen(argc, argv, 0) to compile old generators (originally created for testlib under 0.8.7)", + "Introduced disableFinalizeGuard() to switch off finalization checkings", + "Use join() functions to format a range of items as a single string (separated by spaces or other separators)", + "Use -DENABLE_UNEXPECTED_EOF to enable special exit code (by default, 8) in case of unexpected eof. It is good idea to use it in interactors", + "Use -DUSE_RND_AS_BEFORE_087 to compile in compatibility mode with random behavior of versions before 0.8.7", + "Fixed bug with nan in stringToDouble", + "Fixed issue around overloads for size_t on x64", + "Added attribute 'points' to the XML output in case of result=_points", + "Exit codes can be customized via macros, e.g. -DPE_EXIT_CODE=14", + "Introduced InStream function readWordTo/readTokenTo/readStringTo/readLineTo for faster reading", + "Introduced global functions: format(), englishEnding(), upperCase(), lowerCase(), compress()", + "Manual buffer in InStreams, some IO speed improvements", + "Introduced quitif(bool, const char* pattern, ...) which delegates to quitf() in case of first argument is true", + "Introduced guard against missed quitf() in checker or readEof() in validators", + "Supported readStrictReal/readStrictDouble - to use in validators to check strictly float numbers", + "Supported registerInteraction(argc, argv)", + "Print checker message to the stderr instead of stdout", + "Supported TResult _points to output calculated score, use quitp(...) functions", + "Fixed to be compilable on Mac", + "PC_BASE_EXIT_CODE=50 in case of defined TESTSYS", + "Fixed issues 19-21, added __attribute__ format printf", + "Some bug fixes", + "ouf.readInt(1, 100) and similar calls return WA", + "Modified random_t to avoid integer overflow", + "Truncated checker output [patch by Stepan Gatilov]", + "Renamed class random -> class random_t", + "Supported name parameter for read-and-validation methods, like readInt(1, 2, \"n\")", + "Fixed bug in readDouble()", + "Improved ensuref(), fixed nextLine to work in case of EOF, added startTest()", + "Supported \"partially correct\", example: quitf(_pc(13), \"result=%d\", result)", + "Added shuffle(begin, end), use it instead of random_shuffle(begin, end)", + "Added readLine(const string& ptrn), fixed the logic of readLine() in the validation mode", + "Package extended with samples of generators and validators", + "Written the documentation for classes and public methods in testlib.h", + "Implemented random routine to support generators, use registerGen() to switch it on", + "Implemented strict mode to validate tests, use registerValidation() to switch it on", + "Now ncmp.cpp and wcmp.cpp are return WA if answer is suffix or prefix of the output", + "Added InStream::readLong() and removed InStream::readLongint()", + "Now no footer added to each report by default (use directive FOOTER to switch on)", + "Now every checker has a name, use setName(const char* format, ...) to set it", + "Now it is compatible with TTS (by Kittens Computing)", + "Added \'ensure(condition, message = \"\")\' feature, it works like assert()", + "Fixed compatibility with MS C++ 7.1", + "Added footer with exit code information", + "Added compatibility with EJUDGE (compile with EJUDGE directive)", + "Added compatibility with Contester (compile with CONTESTER directive)" +}; + +#ifdef _MSC_VER +#define _CRT_SECURE_NO_DEPRECATE +#define _CRT_SECURE_NO_WARNINGS +#define _CRT_NO_VA_START_VALIDATION +#endif + +/* Overrides random() for Borland C++. */ +#define random __random_deprecated +#include +#include +#include +#include +#undef random + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef TESTLIB_THROW_EXIT_EXCEPTION_INSTEAD_OF_EXIT +# include +#endif + +#if (_WIN32 || __WIN32__ || __WIN32 || _WIN64 || __WIN64__ || __WIN64 || WINNT || __WINNT || __WINNT__ || __CYGWIN__) +# if !defined(_MSC_VER) || _MSC_VER > 1400 +# define NOMINMAX 1 +# include +# else +# define WORD unsigned short +# include +# endif +# include +# define ON_WINDOWS +# if defined(_MSC_VER) && _MSC_VER > 1400 +# pragma warning( disable : 4127 ) +# pragma warning( disable : 4146 ) +# pragma warning( disable : 4458 ) +# endif +#else +# define WORD unsigned short +# include +#endif + +#if defined(FOR_WINDOWS) && defined(FOR_LINUX) +#error Only one target system is allowed +#endif + +#ifndef LLONG_MIN +#define LLONG_MIN (-9223372036854775807LL - 1) +#endif + +#ifndef ULLONG_MAX +#define ULLONG_MAX (18446744073709551615) +#endif + +#define LF ((char)10) +#define CR ((char)13) +#define TAB ((char)9) +#define SPACE ((char)' ') +#define EOFC (255) + +#ifndef OK_EXIT_CODE +# ifdef CONTESTER +# define OK_EXIT_CODE 0xAC +# else +# define OK_EXIT_CODE 0 +# endif +#endif + +#ifndef WA_EXIT_CODE +# ifdef EJUDGE +# define WA_EXIT_CODE 5 +# elif defined(CONTESTER) +# define WA_EXIT_CODE 0xAB +# else +# define WA_EXIT_CODE 1 +# endif +#endif + +#ifndef PE_EXIT_CODE +# ifdef EJUDGE +# define PE_EXIT_CODE 4 +# elif defined(CONTESTER) +# define PE_EXIT_CODE 0xAA +# else +# define PE_EXIT_CODE 2 +# endif +#endif + +#ifndef FAIL_EXIT_CODE +# ifdef EJUDGE +# define FAIL_EXIT_CODE 6 +# elif defined(CONTESTER) +# define FAIL_EXIT_CODE 0xA3 +# else +# define FAIL_EXIT_CODE 3 +# endif +#endif + +#ifndef DIRT_EXIT_CODE +# ifdef EJUDGE +# define DIRT_EXIT_CODE 6 +# else +# define DIRT_EXIT_CODE 4 +# endif +#endif + +#ifndef POINTS_EXIT_CODE +# define POINTS_EXIT_CODE 7 +#endif + +#ifndef UNEXPECTED_EOF_EXIT_CODE +# define UNEXPECTED_EOF_EXIT_CODE 8 +#endif + +#ifndef PC_BASE_EXIT_CODE +# ifdef TESTSYS +# define PC_BASE_EXIT_CODE 50 +# else +# define PC_BASE_EXIT_CODE 0 +# endif +#endif + +#ifdef __GNUC__ +# define __TESTLIB_STATIC_ASSERT(condition) typedef void* __testlib_static_assert_type[(condition) ? 1 : -1] __attribute__((unused)) +#else +# define __TESTLIB_STATIC_ASSERT(condition) typedef void* __testlib_static_assert_type[(condition) ? 1 : -1] +#endif + +#ifdef ON_WINDOWS +#define I64 "%I64d" +#define U64 "%I64u" +#else +#define I64 "%lld" +#define U64 "%llu" +#endif + +#ifdef _MSC_VER +# define NORETURN __declspec(noreturn) +#elif defined __GNUC__ +# define NORETURN __attribute__ ((noreturn)) +#else +# define NORETURN +#endif + +static char __testlib_format_buffer[16777216]; +static int __testlib_format_buffer_usage_count = 0; + +#define FMT_TO_RESULT(fmt, cstr, result) std::string result; \ + if (__testlib_format_buffer_usage_count != 0) \ + __testlib_fail("FMT_TO_RESULT::__testlib_format_buffer_usage_count != 0"); \ + __testlib_format_buffer_usage_count++; \ + va_list ap; \ + va_start(ap, fmt); \ + vsnprintf(__testlib_format_buffer, sizeof(__testlib_format_buffer), cstr, ap); \ + va_end(ap); \ + __testlib_format_buffer[sizeof(__testlib_format_buffer) - 1] = 0; \ + result = std::string(__testlib_format_buffer); \ + __testlib_format_buffer_usage_count--; \ + +#ifdef __GNUC__ +__attribute__ ((format (printf, 1, 2))) +#endif +std::string testlib_format_(const char *fmt, ...); +std::string testlib_format_(const std::string fmt, ...); + +const long long __TESTLIB_LONGLONG_MAX = 9223372036854775807LL; +const int __TESTLIB_MAX_TEST_CASE = 1073741823; + +int __testlib_exitCode; + +bool __testlib_hasTestCase; +int __testlib_testCase = -1; + +void setTestCase(int testCase); + +void unsetTestCase() { + __testlib_hasTestCase = false; + __testlib_testCase = -1; +} + +NORETURN static void __testlib_fail(const std::string &message); + +template +#ifdef __GNUC__ +__attribute__((const)) +#endif +static inline T __testlib_abs(const T &x) { + return x > 0 ? x : -x; +} + +template +#ifdef __GNUC__ +__attribute__((const)) +#endif +static inline T __testlib_min(const T &a, const T &b) { + return a < b ? a : b; +} + +template +#ifdef __GNUC__ +__attribute__((const)) +#endif +static inline T __testlib_max(const T &a, const T &b) { + return a > b ? a : b; +} + +template +#ifdef __GNUC__ +__attribute__((const)) +#endif +static inline T __testlib_crop(T value, T a, T b) { + return __testlib_min(__testlib_max(value, a), --b); +} + +#ifdef __GNUC__ +__attribute__((const)) +#endif +static inline double __testlib_crop(double value, double a, double b) { + value = __testlib_min(__testlib_max(value, a), b); + if (value >= b) + value = std::nexttoward(b, a); + return value; +} + +static bool __testlib_prelimIsNaN(double r) { + volatile double ra = r; +#ifndef __BORLANDC__ + return ((ra != ra) == true) && ((ra == ra) == false) && ((1.0 > ra) == false) && ((1.0 < ra) == false); +#else + return std::_isnan(ra); +#endif +} + +#ifdef __GNUC__ +__attribute__((const)) +#endif +static std::string removeDoubleTrailingZeroes(std::string value) { + while (!value.empty() && value[value.length() - 1] == '0' && value.find('.') != std::string::npos) + value = value.substr(0, value.length() - 1); + if (!value.empty() && value[value.length() - 1] == '.') + return value + '0'; + else + return value; +} + +#ifdef __GNUC__ +__attribute__((const)) +#endif +inline std::string upperCase(std::string s) { + for (size_t i = 0; i < s.length(); i++) + if ('a' <= s[i] && s[i] <= 'z') + s[i] = char(s[i] - 'a' + 'A'); + return s; +} + +#ifdef __GNUC__ +__attribute__((const)) +#endif +inline std::string lowerCase(std::string s) { + for (size_t i = 0; i < s.length(); i++) + if ('A' <= s[i] && s[i] <= 'Z') + s[i] = char(s[i] - 'A' + 'a'); + return s; +} + +#ifdef __GNUC__ +__attribute__((const)) +#endif +static std::string __testlib_part(const std::string &s); + +static bool __testlib_isNaN(double r) { + __TESTLIB_STATIC_ASSERT(sizeof(double) == sizeof(long long)); + volatile double ra = r; + long long llr1, llr2; + std::memcpy((void *) &llr1, (void *) &ra, sizeof(double)); + ra = -ra; + std::memcpy((void *) &llr2, (void *) &ra, sizeof(double)); + long long llnan = 0xFFF8000000000000LL; + return __testlib_prelimIsNaN(r) || llnan == llr1 || llnan == llr2; +} + +static double __testlib_nan() { + __TESTLIB_STATIC_ASSERT(sizeof(double) == sizeof(long long)); +#ifndef NAN + long long llnan = 0xFFF8000000000000LL; + double nan; + std::memcpy(&nan, &llnan, sizeof(double)); + return nan; +#else + return NAN; +#endif +} + +static bool __testlib_isInfinite(double r) { + volatile double ra = r; + return (ra > 1E300 || ra < -1E300); +} + +#ifdef __GNUC__ +__attribute__((const)) +#endif +inline bool doubleCompare(double expected, double result, double MAX_DOUBLE_ERROR) { + MAX_DOUBLE_ERROR += 1E-15; + if (__testlib_isNaN(expected)) { + return __testlib_isNaN(result); + } else if (__testlib_isInfinite(expected)) { + if (expected > 0) { + return result > 0 && __testlib_isInfinite(result); + } else { + return result < 0 && __testlib_isInfinite(result); + } + } else if (__testlib_isNaN(result) || __testlib_isInfinite(result)) { + return false; + } else if (__testlib_abs(result - expected) <= MAX_DOUBLE_ERROR) { + return true; + } else { + double minv = __testlib_min(expected * (1.0 - MAX_DOUBLE_ERROR), + expected * (1.0 + MAX_DOUBLE_ERROR)); + double maxv = __testlib_max(expected * (1.0 - MAX_DOUBLE_ERROR), + expected * (1.0 + MAX_DOUBLE_ERROR)); + return result >= minv && result <= maxv; + } +} + +#ifdef __GNUC__ +__attribute__((const)) +#endif +inline double doubleDelta(double expected, double result) { + double absolute = __testlib_abs(result - expected); + + if (__testlib_abs(expected) > 1E-9) { + double relative = __testlib_abs(absolute / expected); + return __testlib_min(absolute, relative); + } else + return absolute; +} + +/** It does nothing on non-windows and files differ from stdin/stdout/stderr. */ +static void __testlib_set_binary(std::FILE *file) { + if (NULL != file) { +#ifdef ON_WINDOWS +# ifdef _O_BINARY + if (stdin == file) +# ifdef STDIN_FILENO + return void(_setmode(STDIN_FILENO, _O_BINARY)); +# else + return void(_setmode(_fileno(stdin), _O_BINARY)); +# endif + if (stdout == file) +# ifdef STDOUT_FILENO + return void(_setmode(STDOUT_FILENO, _O_BINARY)); +# else + return void(_setmode(_fileno(stdout), _O_BINARY)); +# endif + if (stderr == file) +# ifdef STDERR_FILENO + return void(_setmode(STDERR_FILENO, _O_BINARY)); +# else + return void(_setmode(_fileno(stderr), _O_BINARY)); +# endif +# elif O_BINARY + if (stdin == file) +# ifdef STDIN_FILENO + return void(setmode(STDIN_FILENO, O_BINARY)); +# else + return void(setmode(fileno(stdin), O_BINARY)); +# endif + if (stdout == file) +# ifdef STDOUT_FILENO + return void(setmode(STDOUT_FILENO, O_BINARY)); +# else + return void(setmode(fileno(stdout), O_BINARY)); +# endif + if (stderr == file) +# ifdef STDERR_FILENO + return void(setmode(STDERR_FILENO, O_BINARY)); +# else + return void(setmode(fileno(stderr), O_BINARY)); +# endif +# endif +#endif + } +} + +#if __cplusplus > 199711L || defined(_MSC_VER) +template +#ifdef __GNUC__ +__attribute__((const)) +#endif +static std::string vtos(const T &t, std::true_type) { + if (t == 0) + return "0"; + else { + T n(t); + bool negative = n < 0; + std::string s; + while (n != 0) { + T digit = n % 10; + if (digit < 0) + digit = -digit; + s += char('0' + digit); + n /= 10; + } + std::reverse(s.begin(), s.end()); + return negative ? "-" + s : s; + } +} + +template +static std::string vtos(const T &t, std::false_type) { + std::string s; + static std::stringstream ss; + ss.str(std::string()); + ss.clear(); + ss << t; + ss >> s; + return s; +} + +template +static std::string vtos(const T &t) { + return vtos(t, std::is_integral()); +} + +/* signed case. */ +template +static std::string toHumanReadableString(const T &n, std::false_type) { + if (n == 0) + return vtos(n); + int trailingZeroCount = 0; + T n_ = n; + while (n_ % 10 == 0) + n_ /= 10, trailingZeroCount++; + if (trailingZeroCount >= 7) { + if (n_ == 1) + return "10^" + vtos(trailingZeroCount); + else if (n_ == -1) + return "-10^" + vtos(trailingZeroCount); + else + return vtos(n_) + "*10^" + vtos(trailingZeroCount); + } else + return vtos(n); +} + +/* unsigned case. */ +template +static std::string toHumanReadableString(const T &n, std::true_type) { + if (n == 0) + return vtos(n); + int trailingZeroCount = 0; + T n_ = n; + while (n_ % 10 == 0) + n_ /= 10, trailingZeroCount++; + if (trailingZeroCount >= 7) { + if (n_ == 1) + return "10^" + vtos(trailingZeroCount); + else + return vtos(n_) + "*10^" + vtos(trailingZeroCount); + } else + return vtos(n); +} + +template +static std::string toHumanReadableString(const T &n) { + return toHumanReadableString(n, std::is_unsigned()); +} +#else +template +static std::string vtos(const T& t) +{ + std::string s; + static std::stringstream ss; + ss.str(std::string()); + ss.clear(); + ss << t; + ss >> s; + return s; +} + +template +static std::string toHumanReadableString(const T &n) { + return vtos(n); +} +#endif + +template +static std::string toString(const T &t) { + return vtos(t); +} + +#if __cplusplus > 199711L || defined(_MSC_VER) +/* opts */ +void prepareOpts(int argc, char* argv[]); +#endif + +FILE* testlib_fopen_(const char* path, const char* mode) { +#ifdef _MSC_VER + FILE* result = NULL; + if (fopen_s(&result, path, mode) != 0) + return NULL; + else + return result; +#else + return std::fopen(path, mode); +#endif +} + +FILE* testlib_freopen_(const char* path, const char* mode, FILE* file) { +#ifdef _MSC_VER + FILE* result = NULL; + if (freopen_s(&result, path, mode, file) != 0) + return NULL; + else + return result; +#else + return std::freopen(path, mode, file); +#endif +} + +/* + * Very simple regex-like pattern. + * It used for two purposes: validation and generation. + * + * For example, pattern("[a-z]{1,5}").next(rnd) will return + * random string from lowercase latin letters with length + * from 1 to 5. It is easier to call rnd.next("[a-z]{1,5}") + * for the same effect. + * + * Another samples: + * "mike|john" will generate (match) "mike" or "john"; + * "-?[1-9][0-9]{0,3}" will generate (match) non-zero integers from -9999 to 9999; + * "id-([ac]|b{2})" will generate (match) "id-a", "id-bb", "id-c"; + * "[^0-9]*" will match sequences (empty or non-empty) without digits, you can't + * use it for generations. + * + * You can't use pattern for generation if it contains meta-symbol '*'. Also it + * is not recommended to use it for char-sets with meta-symbol '^' like [^a-z]. + * + * For matching very simple greedy algorithm is used. For example, pattern + * "[0-9]?1" will not match "1", because of greedy nature of matching. + * Alternations (meta-symbols "|") are processed with brute-force algorithm, so + * do not use many alternations in one expression. + * + * If you want to use one expression many times it is better to compile it into + * a single pattern like "pattern p("[a-z]+")". Later you can use + * "p.matches(std::string s)" or "p.next(random_t& rd)" to check matching or generate + * new string by pattern. + * + * Simpler way to read token and check it for pattern matching is "inf.readToken("[a-z]+")". + * + * All spaces are ignored in regex, unless escaped with \. For example, ouf.readLine("NO SOLUTION") + * will expect "NOSOLUTION", the correct call should be ouf.readLine("NO\\ SOLUTION") or + * ouf.readLine(R"(NO\ SOLUTION)") if you prefer raw string literals from C++11. + */ +class random_t; + +class pattern { +public: + /* Create pattern instance by string. */ + pattern(std::string s); + + /* Generate new string by pattern and given random_t. */ + std::string next(random_t &rnd) const; + + /* Checks if given string match the pattern. */ + bool matches(const std::string &s) const; + + /* Returns source string of the pattern. */ + std::string src() const; + +private: + bool matches(const std::string &s, size_t pos) const; + + std::string s; + std::vector children; + std::vector chars; + int from; + int to; +}; + +/* + * Use random_t instances to generate random values. It is preferred + * way to use randoms instead of rand() function or self-written + * randoms. + * + * Testlib defines global variable "rnd" of random_t class. + * Use registerGen(argc, argv, 1) to setup random_t seed be command + * line (to use latest random generator version). + * + * Random generates uniformly distributed values if another strategy is + * not specified explicitly. + */ +class random_t { +private: + unsigned long long seed; + static const unsigned long long multiplier; + static const unsigned long long addend; + static const unsigned long long mask; + static const int lim; + + long long nextBits(int bits) { + if (bits <= 48) { + seed = (seed * multiplier + addend) & mask; + return (long long) (seed >> (48 - bits)); + } else { + if (bits > 63) + __testlib_fail("random_t::nextBits(int bits): n must be less than 64"); + + int lowerBitCount = (random_t::version == 0 ? 31 : 32); + + long long left = (nextBits(31) << 32); + long long right = nextBits(lowerBitCount); + + return left ^ right; + } + } + +public: + static int version; + + /* New random_t with fixed seed. */ + random_t() + : seed(3905348978240129619LL) { + } + + /* Sets seed by command line. */ + void setSeed(int argc, char *argv[]) { + random_t p; + + seed = 3905348978240129619LL; + for (int i = 1; i < argc; i++) { + std::size_t le = std::strlen(argv[i]); + for (std::size_t j = 0; j < le; j++) + seed = seed * multiplier + (unsigned int) (argv[i][j]) + addend; + seed += multiplier / addend; + } + + seed = seed & mask; + } + + /* Sets seed by given value. */ + void setSeed(long long _seed) { + seed = (unsigned long long) _seed; + seed = (seed ^ multiplier) & mask; + } + +#ifndef __BORLANDC__ + + /* Random string value by given pattern (see pattern documentation). */ + std::string next(const std::string &ptrn) { + pattern p(ptrn); + return p.next(*this); + } + +#else + /* Random string value by given pattern (see pattern documentation). */ + std::string next(std::string ptrn) + { + pattern p(ptrn); + return p.next(*this); + } +#endif + + /* Random value in range [0, n-1]. */ + int next(int n) { + if (n <= 0) + __testlib_fail("random_t::next(int n): n must be positive"); + + if ((n & -n) == n) // n is a power of 2 + return (int) ((n * (long long) nextBits(31)) >> 31); + + const long long limit = INT_MAX / n * n; + + long long bits; + do { + bits = nextBits(31); + } while (bits >= limit); + + return int(bits % n); + } + + /* Random value in range [0, n-1]. */ + unsigned int next(unsigned int n) { + if (n >= INT_MAX) + __testlib_fail("random_t::next(unsigned int n): n must be less INT_MAX"); + return (unsigned int) next(int(n)); + } + + /* Random value in range [0, n-1]. */ + long long next(long long n) { + if (n <= 0) + __testlib_fail("random_t::next(long long n): n must be positive"); + + const long long limit = __TESTLIB_LONGLONG_MAX / n * n; + + long long bits; + do { + bits = nextBits(63); + } while (bits >= limit); + + return bits % n; + } + + /* Random value in range [0, n-1]. */ + unsigned long long next(unsigned long long n) { + if (n >= (unsigned long long) (__TESTLIB_LONGLONG_MAX)) + __testlib_fail("random_t::next(unsigned long long n): n must be less LONGLONG_MAX"); + return (unsigned long long) next((long long) (n)); + } + + /* Random value in range [0, n-1]. */ + long next(long n) { + return (long) next((long long) (n)); + } + + /* Random value in range [0, n-1]. */ + unsigned long next(unsigned long n) { + if (n >= (unsigned long) (LONG_MAX)) + __testlib_fail("random_t::next(unsigned long n): n must be less LONG_MAX"); + return (unsigned long) next((unsigned long long) (n)); + } + + /* Returns random value in range [from,to]. */ + int next(int from, int to) { + return int(next((long long) to - from + 1) + from); + } + + /* Returns random value in range [from,to]. */ + unsigned int next(unsigned int from, unsigned int to) { + return (unsigned int) (next((long long) to - from + 1) + from); + } + + /* Returns random value in range [from,to]. */ + long long next(long long from, long long to) { + return next(to - from + 1) + from; + } + + /* Returns random value in range [from,to]. */ + unsigned long long next(unsigned long long from, unsigned long long to) { + if (from > to) + __testlib_fail("random_t::next(unsigned long long from, unsigned long long to): from can't not exceed to"); + return next(to - from + 1) + from; + } + + /* Returns random value in range [from,to]. */ + long next(long from, long to) { + return next(to - from + 1) + from; + } + + /* Returns random value in range [from,to]. */ + unsigned long next(unsigned long from, unsigned long to) { + if (from > to) + __testlib_fail("random_t::next(unsigned long from, unsigned long to): from can't not exceed to"); + return next(to - from + 1) + from; + } + + /* Random double value in range [0, 1). */ + double next() { + long long left = ((long long) (nextBits(26)) << 27); + long long right = nextBits(27); + return __testlib_crop((double) (left + right) / (double) (1LL << 53), 0.0, 1.0); + } + + /* Random double value in range [0, n). */ + double next(double n) { + if (n <= 0.0) + __testlib_fail("random_t::next(double): n should be positive"); + return __testlib_crop(n * next(), 0.0, n); + } + + /* Random double value in range [from, to). */ + double next(double from, double to) { + if (from >= to) + __testlib_fail("random_t::next(double from, double to): from should be strictly less than to"); + return next(to - from) + from; + } + + /* Returns random element from container. */ + template + typename Container::value_type any(const Container &c) { + int size = int(c.size()); + if (size <= 0) + __testlib_fail("random_t::any(const Container& c): c.size() must be positive"); + typename Container::const_iterator it = c.begin(); + std::advance(it, next(size)); + return *it; + } + + /* Returns random element from iterator range. */ + template + typename Iter::value_type any(const Iter &begin, const Iter &end) { + int size = static_cast(std::distance(begin, end)); + if (size <= 0) + __testlib_fail("random_t::any(const Iter& begin, const Iter& end): range must have positive length"); + Iter it = begin; + std::advance(it, next(size)); + return *it; + } + + /* Random string value by given pattern (see pattern documentation). */ +#ifdef __GNUC__ + __attribute__ ((format (printf, 2, 3))) +#endif + std::string next(const char *format, ...) { + FMT_TO_RESULT(format, format, ptrn); + return next(ptrn); + } + + /* + * Weighted next. If type == 0 than it is usual "next()". + * + * If type = 1, than it returns "max(next(), next())" + * (the number of "max" functions equals to "type"). + * + * If type < 0, than "max" function replaces with "min". + */ + int wnext(int n, int type) { + if (n <= 0) + __testlib_fail("random_t::wnext(int n, int type): n must be positive"); + + if (abs(type) < random_t::lim) { + int result = next(n); + + for (int i = 0; i < +type; i++) + result = __testlib_max(result, next(n)); + + for (int i = 0; i < -type; i++) + result = __testlib_min(result, next(n)); + + return result; + } else { + double p; + + if (type > 0) + p = std::pow(next() + 0.0, 1.0 / (type + 1)); + else + p = 1 - std::pow(next() + 0.0, 1.0 / (-type + 1)); + + return __testlib_crop((int) (double(n) * p), 0, n); + } + } + + /* See wnext(int, int). It uses the same algorithms. */ + long long wnext(long long n, int type) { + if (n <= 0) + __testlib_fail("random_t::wnext(long long n, int type): n must be positive"); + + if (abs(type) < random_t::lim) { + long long result = next(n); + + for (int i = 0; i < +type; i++) + result = __testlib_max(result, next(n)); + + for (int i = 0; i < -type; i++) + result = __testlib_min(result, next(n)); + + return result; + } else { + double p; + + if (type > 0) + p = std::pow(next() + 0.0, 1.0 / (type + 1)); + else + p = 1 - std::pow(next() + 0.0, 1.0 / (-type + 1)); + + return __testlib_crop((long long) (double(n) * p), 0LL, n); + } + } + + /* Returns value in [0, n). See wnext(int, int). It uses the same algorithms. */ + double wnext(double n, int type) { + if (n <= 0) + __testlib_fail("random_t::wnext(double n, int type): n must be positive"); + + if (abs(type) < random_t::lim) { + double result = next(); + + for (int i = 0; i < +type; i++) + result = __testlib_max(result, next()); + + for (int i = 0; i < -type; i++) + result = __testlib_min(result, next()); + + return n * result; + } else { + double p; + + if (type > 0) + p = std::pow(next() + 0.0, 1.0 / (type + 1)); + else + p = 1 - std::pow(next() + 0.0, 1.0 / (-type + 1)); + + return __testlib_crop(n * p, 0.0, n); + } + } + + /* Returns value in [0, 1). See wnext(int, int). It uses the same algorithms. */ + double wnext(int type) { + return wnext(1.0, type); + } + + /* See wnext(int, int). It uses the same algorithms. */ + unsigned int wnext(unsigned int n, int type) { + if (n >= INT_MAX) + __testlib_fail("random_t::wnext(unsigned int n, int type): n must be less INT_MAX"); + return (unsigned int) wnext(int(n), type); + } + + /* See wnext(int, int). It uses the same algorithms. */ + unsigned long long wnext(unsigned long long n, int type) { + if (n >= (unsigned long long) (__TESTLIB_LONGLONG_MAX)) + __testlib_fail("random_t::wnext(unsigned long long n, int type): n must be less LONGLONG_MAX"); + + return (unsigned long long) wnext((long long) (n), type); + } + + /* See wnext(int, int). It uses the same algorithms. */ + long wnext(long n, int type) { + return (long) wnext((long long) (n), type); + } + + /* See wnext(int, int). It uses the same algorithms. */ + unsigned long wnext(unsigned long n, int type) { + if (n >= (unsigned long) (LONG_MAX)) + __testlib_fail("random_t::wnext(unsigned long n, int type): n must be less LONG_MAX"); + + return (unsigned long) wnext((unsigned long long) (n), type); + } + + /* Returns weighted random value in range [from, to]. */ + int wnext(int from, int to, int type) { + if (from > to) + __testlib_fail("random_t::wnext(int from, int to, int type): from can't not exceed to"); + return wnext(to - from + 1, type) + from; + } + + /* Returns weighted random value in range [from, to]. */ + int wnext(unsigned int from, unsigned int to, int type) { + if (from > to) + __testlib_fail("random_t::wnext(unsigned int from, unsigned int to, int type): from can't not exceed to"); + return int(wnext(to - from + 1, type) + from); + } + + /* Returns weighted random value in range [from, to]. */ + long long wnext(long long from, long long to, int type) { + if (from > to) + __testlib_fail("random_t::wnext(long long from, long long to, int type): from can't not exceed to"); + return wnext(to - from + 1, type) + from; + } + + /* Returns weighted random value in range [from, to]. */ + unsigned long long wnext(unsigned long long from, unsigned long long to, int type) { + if (from > to) + __testlib_fail( + "random_t::wnext(unsigned long long from, unsigned long long to, int type): from can't not exceed to"); + return wnext(to - from + 1, type) + from; + } + + /* Returns weighted random value in range [from, to]. */ + long wnext(long from, long to, int type) { + if (from > to) + __testlib_fail("random_t::wnext(long from, long to, int type): from can't not exceed to"); + return wnext(to - from + 1, type) + from; + } + + /* Returns weighted random value in range [from, to]. */ + unsigned long wnext(unsigned long from, unsigned long to, int type) { + if (from > to) + __testlib_fail("random_t::wnext(unsigned long from, unsigned long to, int type): from can't not exceed to"); + return wnext(to - from + 1, type) + from; + } + + /* Returns weighted random double value in range [from, to). */ + double wnext(double from, double to, int type) { + if (from >= to) + __testlib_fail("random_t::wnext(double from, double to, int type): from should be strictly less than to"); + return wnext(to - from, type) + from; + } + + /* Returns weighted random element from container. */ + template + typename Container::value_type wany(const Container &c, int type) { + int size = int(c.size()); + if (size <= 0) + __testlib_fail("random_t::wany(const Container& c, int type): c.size() must be positive"); + typename Container::const_iterator it = c.begin(); + std::advance(it, wnext(size, type)); + return *it; + } + + /* Returns weighted random element from iterator range. */ + template + typename Iter::value_type wany(const Iter &begin, const Iter &end, int type) { + int size = static_cast(std::distance(begin, end)); + if (size <= 0) + __testlib_fail( + "random_t::any(const Iter& begin, const Iter& end, int type): range must have positive length"); + Iter it = begin; + std::advance(it, wnext(size, type)); + return *it; + } + + /* Returns random permutation of the given size (values are between `first` and `first`+size-1)*/ + template + std::vector perm(T size, E first) { + if (size < 0) + __testlib_fail("random_t::perm(T size, E first = 0): size must non-negative"); + else if (size == 0) + return std::vector(); + std::vector p(size); + E current = first; + for (T i = 0; i < size; i++) + p[i] = current++; + if (size > 1) + for (T i = 1; i < size; i++) + std::swap(p[i], p[next(i + 1)]); + return p; + } + + /* Returns random permutation of the given size (values are between 0 and size-1)*/ + template + std::vector perm(T size) { + return perm(size, T(0)); + } + + /* Returns `size` unordered (unsorted) distinct numbers between `from` and `to`. */ + template + std::vector distinct(int size, T from, T to) { + std::vector result; + if (size == 0) + return result; + + if (from > to) + __testlib_fail("random_t::distinct expected from <= to"); + + if (size < 0) + __testlib_fail("random_t::distinct expected size >= 0"); + + uint64_t n = to - from + 1; + if (uint64_t(size) > n) + __testlib_fail("random_t::distinct expected size <= to - from + 1"); + + double expected = 0.0; + for (int i = 1; i <= size; i++) + expected += double(n) / double(n - i + 1); + + if (expected < double(n)) { + std::set vals; + while (int(vals.size()) < size) { + T x = T(next(from, to)); + if (vals.insert(x).second) + result.push_back(x); + } + } else { + if (n > 1000000000) + __testlib_fail("random_t::distinct here expected to - from + 1 <= 1000000000"); + std::vector p(perm(int(n), from)); + result.insert(result.end(), p.begin(), p.begin() + size); + } + + return result; + } + + /* Returns `size` unordered (unsorted) distinct numbers between `0` and `upper`-1. */ + template + std::vector distinct(int size, T upper) { + if (size < 0) + __testlib_fail("random_t::distinct expected size >= 0"); + if (size == 0) + return std::vector(); + + if (upper <= 0) + __testlib_fail("random_t::distinct expected upper > 0"); + if (size > upper) + __testlib_fail("random_t::distinct expected size <= upper"); + + return distinct(size, T(0), upper - 1); + } + + /* Returns random (unsorted) partition which is a representation of sum as a sum of integers not less than min_part. */ + template + std::vector partition(int size, T sum, T min_part) { + if (size < 0) + __testlib_fail("random_t::partition: size < 0"); + if (size == 0 && sum != 0) + __testlib_fail("random_t::partition: size == 0 && sum != 0"); + if (min_part * size > sum) + __testlib_fail("random_t::partition: min_part * size > sum"); + if (size == 0 && sum == 0) + return std::vector(); + + T sum_ = sum; + sum -= min_part * size; + + std::vector septums(size); + std::vector d = distinct(size - 1, T(1), T(sum + size - 1)); + for (int i = 0; i + 1 < size; i++) + septums[i + 1] = d[i]; + sort(septums.begin(), septums.end()); + + std::vector result(size); + for (int i = 0; i + 1 < size; i++) + result[i] = septums[i + 1] - septums[i] - 1; + result[size - 1] = sum + size - 1 - septums.back(); + + for (std::size_t i = 0; i < result.size(); i++) + result[i] += min_part; + + T result_sum = 0; + for (std::size_t i = 0; i < result.size(); i++) + result_sum += result[i]; + if (result_sum != sum_) + __testlib_fail("random_t::partition: partition sum is expected to be the given sum"); + + if (*std::min_element(result.begin(), result.end()) < min_part) + __testlib_fail("random_t::partition: partition min is expected to be no less than the given min_part"); + + if (int(result.size()) != size || result.size() != (size_t) size) + __testlib_fail("random_t::partition: partition size is expected to be equal to the given size"); + + return result; + } + + /* Returns random (unsorted) partition which is a representation of sum as a sum of positive integers. */ + template + std::vector partition(int size, T sum) { + return partition(size, sum, T(1)); + } +}; + +const int random_t::lim = 25; +const unsigned long long random_t::multiplier = 0x5DEECE66DLL; +const unsigned long long random_t::addend = 0xBLL; +const unsigned long long random_t::mask = (1LL << 48) - 1; +int random_t::version = -1; + +/* Pattern implementation */ +bool pattern::matches(const std::string &s) const { + return matches(s, 0); +} + +static bool __pattern_isSlash(const std::string &s, size_t pos) { + return s[pos] == '\\'; +} + +#ifdef __GNUC__ +__attribute__((pure)) +#endif +static bool __pattern_isCommandChar(const std::string &s, size_t pos, char value) { + if (pos >= s.length()) + return false; + + int slashes = 0; + + int before = int(pos) - 1; + while (before >= 0 && s[before] == '\\') + before--, slashes++; + + return slashes % 2 == 0 && s[pos] == value; +} + +static char __pattern_getChar(const std::string &s, size_t &pos) { + if (__pattern_isSlash(s, pos)) + pos += 2; + else + pos++; + + return s[pos - 1]; +} + +#ifdef __GNUC__ +__attribute__((pure)) +#endif +static int __pattern_greedyMatch(const std::string &s, size_t pos, const std::vector chars) { + int result = 0; + + while (pos < s.length()) { + char c = s[pos++]; + if (!std::binary_search(chars.begin(), chars.end(), c)) + break; + else + result++; + } + + return result; +} + +std::string pattern::src() const { + return s; +} + +bool pattern::matches(const std::string &s, size_t pos) const { + std::string result; + + if (to > 0) { + int size = __pattern_greedyMatch(s, pos, chars); + if (size < from) + return false; + if (size > to) + size = to; + pos += size; + } + + if (children.size() > 0) { + for (size_t child = 0; child < children.size(); child++) + if (children[child].matches(s, pos)) + return true; + return false; + } else + return pos == s.length(); +} + +std::string pattern::next(random_t &rnd) const { + std::string result; + result.reserve(20); + + if (to == INT_MAX) + __testlib_fail("pattern::next(random_t& rnd): can't process character '*' for generation"); + + if (to > 0) { + int count = rnd.next(to - from + 1) + from; + for (int i = 0; i < count; i++) + result += chars[rnd.next(int(chars.size()))]; + } + + if (children.size() > 0) { + int child = rnd.next(int(children.size())); + result += children[child].next(rnd); + } + + return result; +} + +static void __pattern_scanCounts(const std::string &s, size_t &pos, int &from, int &to) { + if (pos >= s.length()) { + from = to = 1; + return; + } + + if (__pattern_isCommandChar(s, pos, '{')) { + std::vector parts; + std::string part; + + pos++; + + while (pos < s.length() && !__pattern_isCommandChar(s, pos, '}')) { + if (__pattern_isCommandChar(s, pos, ',')) + parts.push_back(part), part = "", pos++; + else + part += __pattern_getChar(s, pos); + } + + if (part != "") + parts.push_back(part); + + if (!__pattern_isCommandChar(s, pos, '}')) + __testlib_fail("pattern: Illegal pattern (or part) \"" + s + "\""); + + pos++; + + if (parts.size() < 1 || parts.size() > 2) + __testlib_fail("pattern: Illegal pattern (or part) \"" + s + "\""); + + std::vector numbers; + + for (size_t i = 0; i < parts.size(); i++) { + if (parts[i].length() == 0) + __testlib_fail("pattern: Illegal pattern (or part) \"" + s + "\""); + int number; +#ifdef _MSC_VER + if (sscanf_s(parts[i].c_str(), "%d", &number) != 1) +#else + if (std::sscanf(parts[i].c_str(), "%d", &number) != 1) +#endif + __testlib_fail("pattern: Illegal pattern (or part) \"" + s + "\""); + numbers.push_back(number); + } + + if (numbers.size() == 1) + from = to = numbers[0]; + else + from = numbers[0], to = numbers[1]; + + if (from > to) + __testlib_fail("pattern: Illegal pattern (or part) \"" + s + "\""); + } else { + if (__pattern_isCommandChar(s, pos, '?')) { + from = 0, to = 1, pos++; + return; + } + + if (__pattern_isCommandChar(s, pos, '*')) { + from = 0, to = INT_MAX, pos++; + return; + } + + if (__pattern_isCommandChar(s, pos, '+')) { + from = 1, to = INT_MAX, pos++; + return; + } + + from = to = 1; + } +} + +static std::vector __pattern_scanCharSet(const std::string &s, size_t &pos) { + if (pos >= s.length()) + __testlib_fail("pattern: Illegal pattern (or part) \"" + s + "\""); + + std::vector result; + + if (__pattern_isCommandChar(s, pos, '[')) { + pos++; + bool negative = __pattern_isCommandChar(s, pos, '^'); + if (negative) + pos++; + + char prev = 0; + + while (pos < s.length() && !__pattern_isCommandChar(s, pos, ']')) { + if (__pattern_isCommandChar(s, pos, '-') && prev != 0) { + pos++; + + if (pos + 1 == s.length() || __pattern_isCommandChar(s, pos, ']')) { + result.push_back(prev); + prev = '-'; + continue; + } + + char next = __pattern_getChar(s, pos); + if (prev > next) + __testlib_fail("pattern: Illegal pattern (or part) \"" + s + "\""); + + for (char c = prev; c != next; c++) + result.push_back(c); + result.push_back(next); + + prev = 0; + } else { + if (prev != 0) + result.push_back(prev); + prev = __pattern_getChar(s, pos); + } + } + + if (prev != 0) + result.push_back(prev); + + if (!__pattern_isCommandChar(s, pos, ']')) + __testlib_fail("pattern: Illegal pattern (or part) \"" + s + "\""); + + pos++; + + if (negative) { + std::sort(result.begin(), result.end()); + std::vector actuals; + for (int code = 0; code < 255; code++) { + char c = char(code); + if (!std::binary_search(result.begin(), result.end(), c)) + actuals.push_back(c); + } + result = actuals; + } + + std::sort(result.begin(), result.end()); + } else + result.push_back(__pattern_getChar(s, pos)); + + return result; +} + +pattern::pattern(std::string s) : s(s), from(0), to(0) { + std::string t; + for (size_t i = 0; i < s.length(); i++) + if (!__pattern_isCommandChar(s, i, ' ')) + t += s[i]; + s = t; + + int opened = 0; + int firstClose = -1; + std::vector seps; + + for (size_t i = 0; i < s.length(); i++) { + if (__pattern_isCommandChar(s, i, '(')) { + opened++; + continue; + } + + if (__pattern_isCommandChar(s, i, ')')) { + opened--; + if (opened == 0 && firstClose == -1) + firstClose = int(i); + continue; + } + + if (opened < 0) + __testlib_fail("pattern: Illegal pattern (or part) \"" + s + "\""); + + if (__pattern_isCommandChar(s, i, '|') && opened == 0) + seps.push_back(int(i)); + } + + if (opened != 0) + __testlib_fail("pattern: Illegal pattern (or part) \"" + s + "\""); + + if (seps.size() == 0 && firstClose + 1 == (int) s.length() + && __pattern_isCommandChar(s, 0, '(') && __pattern_isCommandChar(s, s.length() - 1, ')')) { + children.push_back(pattern(s.substr(1, s.length() - 2))); + } else { + if (seps.size() > 0) { + seps.push_back(int(s.length())); + int last = 0; + + for (size_t i = 0; i < seps.size(); i++) { + children.push_back(pattern(s.substr(last, seps[i] - last))); + last = seps[i] + 1; + } + } else { + size_t pos = 0; + chars = __pattern_scanCharSet(s, pos); + __pattern_scanCounts(s, pos, from, to); + if (pos < s.length()) + children.push_back(pattern(s.substr(pos))); + } + } +} + +/* End of pattern implementation */ + +template +inline bool isEof(C c) { + return c == EOFC; +} + +template +inline bool isEoln(C c) { + return (c == LF || c == CR); +} + +template +inline bool isBlanks(C c) { + return (c == LF || c == CR || c == SPACE || c == TAB); +} + +inline std::string trim(const std::string &s) { + if (s.empty()) + return s; + + int left = 0; + while (left < int(s.length()) && isBlanks(s[left])) + left++; + if (left >= int(s.length())) + return ""; + + int right = int(s.length()) - 1; + while (right >= 0 && isBlanks(s[right])) + right--; + if (right < 0) + return ""; + + return s.substr(left, right - left + 1); +} + +enum TMode { + _input, _output, _answer +}; + +/* Outcomes 6-15 are reserved for future use. */ +enum TResult { + _ok = 0, + _wa = 1, + _pe = 2, + _fail = 3, + _dirt = 4, + _points = 5, + _unexpected_eof = 8, + _partially = 16 +}; + +enum TTestlibMode { + _unknown, _checker, _validator, _generator, _interactor, _scorer +}; + +#define _pc(exitCode) (TResult(_partially + (exitCode))) + +/* Outcomes 6-15 are reserved for future use. */ +const std::string outcomes[] = { + "accepted", + "wrong-answer", + "presentation-error", + "fail", + "fail", +#ifndef PCMS2 + "points", +#else + "relative-scoring", +#endif + "reserved", + "reserved", + "unexpected-eof", + "reserved", + "reserved", + "reserved", + "reserved", + "reserved", + "reserved", + "reserved", + "partially-correct" +}; + +class InputStreamReader { +public: + virtual void setTestCase(int testCase) = 0; + + virtual std::vector getReadChars() = 0; + + virtual int curChar() = 0; + + virtual int nextChar() = 0; + + virtual void skipChar() = 0; + + virtual void unreadChar(int c) = 0; + + virtual std::string getName() = 0; + + virtual bool eof() = 0; + + virtual void close() = 0; + + virtual int getLine() = 0; + + virtual ~InputStreamReader() = 0; +}; + +InputStreamReader::~InputStreamReader() { + // No operations. +} + +class StringInputStreamReader : public InputStreamReader { +private: + std::string s; + size_t pos; + +public: + StringInputStreamReader(const std::string &content) : s(content), pos(0) { + // No operations. + } + + void setTestCase(int) { + __testlib_fail("setTestCase not implemented in StringInputStreamReader"); + } + + std::vector getReadChars() { + __testlib_fail("getReadChars not implemented in StringInputStreamReader"); + } + + int curChar() { + if (pos >= s.length()) + return EOFC; + else + return s[pos]; + } + + int nextChar() { + if (pos >= s.length()) { + pos++; + return EOFC; + } else + return s[pos++]; + } + + void skipChar() { + pos++; + } + + void unreadChar(int c) { + if (pos == 0) + __testlib_fail("StringInputStreamReader::unreadChar(int): pos == 0."); + pos--; + if (pos < s.length()) + s[pos] = char(c); + } + + std::string getName() { + return __testlib_part(s); + } + + int getLine() { + return -1; + } + + bool eof() { + return pos >= s.length(); + } + + void close() { + // No operations. + } +}; + +class FileInputStreamReader : public InputStreamReader { +private: + std::FILE *file; + std::string name; + int line; + std::vector undoChars; + std::vector readChars; + std::vector undoReadChars; + + inline int postprocessGetc(int getcResult) { + if (getcResult != EOF) + return getcResult; + else + return EOFC; + } + + int getc(FILE *file) { + int c; + int rc; + + if (undoChars.empty()) { + c = rc = ::getc(file); + } else { + c = undoChars.back(); + undoChars.pop_back(); + rc = undoReadChars.back(); + undoReadChars.pop_back(); + } + + if (c == LF) + line++; + + readChars.push_back(rc); + return c; + } + + int ungetc(int c/*, FILE* file*/) { + if (!readChars.empty()) { + undoReadChars.push_back(readChars.back()); + readChars.pop_back(); + } + if (c == LF) + line--; + undoChars.push_back(c); + return c; + } + +public: + FileInputStreamReader(std::FILE *file, const std::string &name) : file(file), name(name), line(1) { + // No operations. + } + + void setTestCase(int testCase) { + if (testCase < 0 || testCase > __TESTLIB_MAX_TEST_CASE) + __testlib_fail(testlib_format_("testCase expected fit in [1,%d], but %d doesn't", __TESTLIB_MAX_TEST_CASE, testCase)); + readChars.push_back(testCase + 256); + } + + std::vector getReadChars() { + return readChars; + } + + int curChar() { + if (feof(file)) + return EOFC; + else { + int c = getc(file); + ungetc(c/*, file*/); + return postprocessGetc(c); + } + } + + int nextChar() { + if (feof(file)) + return EOFC; + else + return postprocessGetc(getc(file)); + } + + void skipChar() { + getc(file); + } + + void unreadChar(int c) { + ungetc(c/*, file*/); + } + + std::string getName() { + return name; + } + + int getLine() { + return line; + } + + bool eof() { + if (NULL == file || feof(file)) + return true; + else { + int c = nextChar(); + if (c == EOFC || (c == EOF && feof(file))) + return true; + unreadChar(c); + return false; + } + } + + void close() { + if (NULL != file) { + fclose(file); + file = NULL; + } + } +}; + +class BufferedFileInputStreamReader : public InputStreamReader { +private: + static const size_t BUFFER_SIZE; + static const size_t MAX_UNREAD_COUNT; + + std::FILE *file; + std::string name; + int line; + + char *buffer; + bool *isEof; + int bufferPos; + size_t bufferSize; + + bool refill() { + if (NULL == file) + __testlib_fail("BufferedFileInputStreamReader: file == NULL (" + getName() + ")"); + + if (bufferPos >= int(bufferSize)) { + size_t readSize = fread( + buffer + MAX_UNREAD_COUNT, + 1, + BUFFER_SIZE - MAX_UNREAD_COUNT, + file + ); + + if (readSize < BUFFER_SIZE - MAX_UNREAD_COUNT + && ferror(file)) + __testlib_fail("BufferedFileInputStreamReader: unable to read (" + getName() + ")"); + + bufferSize = MAX_UNREAD_COUNT + readSize; + bufferPos = int(MAX_UNREAD_COUNT); + std::memset(isEof + MAX_UNREAD_COUNT, 0, sizeof(isEof[0]) * readSize); + + return readSize > 0; + } else + return true; + } + + char increment() { + char c; + if ((c = buffer[bufferPos++]) == LF) + line++; + return c; + } + +public: + BufferedFileInputStreamReader(std::FILE *file, const std::string &name) : file(file), name(name), line(1) { + buffer = new char[BUFFER_SIZE]; + isEof = new bool[BUFFER_SIZE]; + bufferSize = MAX_UNREAD_COUNT; + bufferPos = int(MAX_UNREAD_COUNT); + } + + ~BufferedFileInputStreamReader() { + if (NULL != buffer) { + delete[] buffer; + buffer = NULL; + } + if (NULL != isEof) { + delete[] isEof; + isEof = NULL; + } + } + + void setTestCase(int) { + __testlib_fail("setTestCase not implemented in BufferedFileInputStreamReader"); + } + + std::vector getReadChars() { + __testlib_fail("getReadChars not implemented in BufferedFileInputStreamReader"); + } + + int curChar() { + if (!refill()) + return EOFC; + + return isEof[bufferPos] ? EOFC : buffer[bufferPos]; + } + + int nextChar() { + if (!refill()) + return EOFC; + + return isEof[bufferPos] ? EOFC : increment(); + } + + void skipChar() { + increment(); + } + + void unreadChar(int c) { + bufferPos--; + if (bufferPos < 0) + __testlib_fail("BufferedFileInputStreamReader::unreadChar(int): bufferPos < 0"); + isEof[bufferPos] = (c == EOFC); + buffer[bufferPos] = char(c); + if (c == LF) + line--; + } + + std::string getName() { + return name; + } + + int getLine() { + return line; + } + + bool eof() { + return !refill() || EOFC == curChar(); + } + + void close() { + if (NULL != file) { + fclose(file); + file = NULL; + } + } +}; + +const size_t BufferedFileInputStreamReader::BUFFER_SIZE = 2000000; +const size_t BufferedFileInputStreamReader::MAX_UNREAD_COUNT = BufferedFileInputStreamReader::BUFFER_SIZE / 2; + +/* + * Streams to be used for reading data in checkers or validators. + * Each read*() method moves pointer to the next character after the + * read value. + */ +struct InStream { + /* Do not use them. */ + InStream(); + + ~InStream(); + + /* Wrap std::string with InStream. */ + InStream(const InStream &baseStream, std::string content); + + InputStreamReader *reader; + int lastLine; + + std::string name; + TMode mode; + bool opened; + bool stdfile; + bool strict; + + int wordReserveSize; + std::string _tmpReadToken; + + int readManyIteration; + size_t maxFileSize; + size_t maxTokenLength; + size_t maxMessageLength; + + void init(std::string fileName, TMode mode); + + void init(std::FILE *f, TMode mode); + + void setTestCase(int testCase); + std::vector getReadChars(); + + /* Moves stream pointer to the first non-white-space character or EOF. */ + void skipBlanks(); + + /* Returns current character in the stream. Doesn't remove it from stream. */ + char curChar(); + + /* Moves stream pointer one character forward. */ + void skipChar(); + + /* Returns current character and moves pointer one character forward. */ + char nextChar(); + + /* Returns current character and moves pointer one character forward. */ + char readChar(); + + /* As "readChar()" but ensures that the result is equal to given parameter. */ + char readChar(char c); + + /* As "readChar()" but ensures that the result is equal to the space (code=32). */ + char readSpace(); + + /* Puts back the character into the stream. */ + void unreadChar(char c); + + /* Reopens stream, you should not use it. */ + void reset(std::FILE *file = NULL); + + /* Checks that current position is EOF. If not it doesn't move stream pointer. */ + bool eof(); + + /* Moves pointer to the first non-white-space character and calls "eof()". */ + bool seekEof(); + + /* + * Checks that current position contains EOLN. + * If not it doesn't move stream pointer. + * In strict mode expects "#13#10" for windows or "#10" for other platforms. + */ + bool eoln(); + + /* Moves pointer to the first non-space and non-tab character and calls "eoln()". */ + bool seekEoln(); + + /* Moves stream pointer to the first character of the next line (if exists). */ + void nextLine(); + + /* + * Reads new token. Ignores white-spaces into the non-strict mode + * (strict mode is used in validators usually). + */ + std::string readWord(); + + /* The same as "readWord()", it is preferred to use "readToken()". */ + std::string readToken(); + + /* The same as "readWord()", but ensures that token matches to given pattern. */ + std::string readWord(const std::string &ptrn, const std::string &variableName = ""); + + std::string readWord(const pattern &p, const std::string &variableName = ""); + + std::vector + readWords(int size, const std::string &ptrn, const std::string &variablesName = "", int indexBase = 1); + + std::vector + readWords(int size, const pattern &p, const std::string &variablesName = "", int indexBase = 1); + + std::vector readWords(int size, int indexBase = 1); + + /* The same as "readToken()", but ensures that token matches to given pattern. */ + std::string readToken(const std::string &ptrn, const std::string &variableName = ""); + + std::string readToken(const pattern &p, const std::string &variableName = ""); + + std::vector + readTokens(int size, const std::string &ptrn, const std::string &variablesName = "", int indexBase = 1); + + std::vector + readTokens(int size, const pattern &p, const std::string &variablesName = "", int indexBase = 1); + + std::vector readTokens(int size, int indexBase = 1); + + void readWordTo(std::string &result); + + void readWordTo(std::string &result, const pattern &p, const std::string &variableName = ""); + + void readWordTo(std::string &result, const std::string &ptrn, const std::string &variableName = ""); + + void readTokenTo(std::string &result); + + void readTokenTo(std::string &result, const pattern &p, const std::string &variableName = ""); + + void readTokenTo(std::string &result, const std::string &ptrn, const std::string &variableName = ""); + + /* + * Reads new long long value. Ignores white-spaces into the non-strict mode + * (strict mode is used in validators usually). + */ + long long readLong(); + + unsigned long long readUnsignedLong(); + + /* + * Reads new int. Ignores white-spaces into the non-strict mode + * (strict mode is used in validators usually). + */ + int readInteger(); + + /* + * Reads new int. Ignores white-spaces into the non-strict mode + * (strict mode is used in validators usually). + */ + int readInt(); + + /* As "readLong()" but ensures that value in the range [minv,maxv]. */ + long long readLong(long long minv, long long maxv, const std::string &variableName = ""); + + /* Reads space-separated sequence of long longs. */ + std::vector + readLongs(int size, long long minv, long long maxv, const std::string &variablesName = "", int indexBase = 1); + + /* Reads space-separated sequence of long longs. */ + std::vector readLongs(int size, int indexBase = 1); + + unsigned long long + readUnsignedLong(unsigned long long minv, unsigned long long maxv, const std::string &variableName = ""); + + std::vector + readUnsignedLongs(int size, unsigned long long minv, unsigned long long maxv, const std::string &variablesName = "", + int indexBase = 1); + + std::vector readUnsignedLongs(int size, int indexBase = 1); + + unsigned long long readLong(unsigned long long minv, unsigned long long maxv, const std::string &variableName = ""); + + std::vector + readLongs(int size, unsigned long long minv, unsigned long long maxv, const std::string &variablesName = "", + int indexBase = 1); + + /* As "readInteger()" but ensures that value in the range [minv,maxv]. */ + int readInteger(int minv, int maxv, const std::string &variableName = ""); + + /* As "readInt()" but ensures that value in the range [minv,maxv]. */ + int readInt(int minv, int maxv, const std::string &variableName = ""); + + /* Reads space-separated sequence of integers. */ + std::vector + readIntegers(int size, int minv, int maxv, const std::string &variablesName = "", int indexBase = 1); + + /* Reads space-separated sequence of integers. */ + std::vector readIntegers(int size, int indexBase = 1); + + /* Reads space-separated sequence of integers. */ + std::vector readInts(int size, int minv, int maxv, const std::string &variablesName = "", int indexBase = 1); + + /* Reads space-separated sequence of integers. */ + std::vector readInts(int size, int indexBase = 1); + + /* + * Reads new double. Ignores white-spaces into the non-strict mode + * (strict mode is used in validators usually). + */ + double readReal(); + + /* + * Reads new double. Ignores white-spaces into the non-strict mode + * (strict mode is used in validators usually). + */ + double readDouble(); + + /* As "readReal()" but ensures that value in the range [minv,maxv]. */ + double readReal(double minv, double maxv, const std::string &variableName = ""); + + std::vector + readReals(int size, double minv, double maxv, const std::string &variablesName = "", int indexBase = 1); + + std::vector readReals(int size, int indexBase = 1); + + /* As "readDouble()" but ensures that value in the range [minv,maxv]. */ + double readDouble(double minv, double maxv, const std::string &variableName = ""); + + std::vector + readDoubles(int size, double minv, double maxv, const std::string &variablesName = "", int indexBase = 1); + + std::vector readDoubles(int size, int indexBase = 1); + + /* + * As "readReal()" but ensures that value in the range [minv,maxv] and + * number of digit after the decimal point is in range [minAfterPointDigitCount,maxAfterPointDigitCount] + * and number is in the form "[-]digit(s)[.digit(s)]". + */ + double readStrictReal(double minv, double maxv, + int minAfterPointDigitCount, int maxAfterPointDigitCount, + const std::string &variableName = ""); + + std::vector readStrictReals(int size, double minv, double maxv, + int minAfterPointDigitCount, int maxAfterPointDigitCount, + const std::string &variablesName = "", int indexBase = 1); + + /* + * As "readDouble()" but ensures that value in the range [minv,maxv] and + * number of digit after the decimal point is in range [minAfterPointDigitCount,maxAfterPointDigitCount] + * and number is in the form "[-]digit(s)[.digit(s)]". + */ + double readStrictDouble(double minv, double maxv, + int minAfterPointDigitCount, int maxAfterPointDigitCount, + const std::string &variableName = ""); + + std::vector readStrictDoubles(int size, double minv, double maxv, + int minAfterPointDigitCount, int maxAfterPointDigitCount, + const std::string &variablesName = "", int indexBase = 1); + + /* As readLine(). */ + std::string readString(); + + /* Read many lines. */ + std::vector readStrings(int size, int indexBase = 1); + + /* See readLine(). */ + void readStringTo(std::string &result); + + /* The same as "readLine()/readString()", but ensures that line matches to the given pattern. */ + std::string readString(const pattern &p, const std::string &variableName = ""); + + /* The same as "readLine()/readString()", but ensures that line matches to the given pattern. */ + std::string readString(const std::string &ptrn, const std::string &variableName = ""); + + /* Read many lines. */ + std::vector + readStrings(int size, const pattern &p, const std::string &variableName = "", int indexBase = 1); + + /* Read many lines. */ + std::vector + readStrings(int size, const std::string &ptrn, const std::string &variableName = "", int indexBase = 1); + + /* The same as "readLine()/readString()", but ensures that line matches to the given pattern. */ + void readStringTo(std::string &result, const pattern &p, const std::string &variableName = ""); + + /* The same as "readLine()/readString()", but ensures that line matches to the given pattern. */ + void readStringTo(std::string &result, const std::string &ptrn, const std::string &variableName = ""); + + /* + * Reads line from the current position to EOLN or EOF. Moves stream pointer to + * the first character of the new line (if possible). + */ + std::string readLine(); + + /* Read many lines. */ + std::vector readLines(int size, int indexBase = 1); + + /* See readLine(). */ + void readLineTo(std::string &result); + + /* The same as "readLine()", but ensures that line matches to the given pattern. */ + std::string readLine(const pattern &p, const std::string &variableName = ""); + + /* The same as "readLine()", but ensures that line matches to the given pattern. */ + std::string readLine(const std::string &ptrn, const std::string &variableName = ""); + + /* Read many lines. */ + std::vector + readLines(int size, const pattern &p, const std::string &variableName = "", int indexBase = 1); + + /* Read many lines. */ + std::vector + readLines(int size, const std::string &ptrn, const std::string &variableName = "", int indexBase = 1); + + /* The same as "readLine()", but ensures that line matches to the given pattern. */ + void readLineTo(std::string &result, const pattern &p, const std::string &variableName = ""); + + /* The same as "readLine()", but ensures that line matches to the given pattern. */ + void readLineTo(std::string &result, const std::string &ptrn, const std::string &variableName = ""); + + /* Reads EOLN or fails. Use it in validators. Calls "eoln()" method internally. */ + void readEoln(); + + /* Reads EOF or fails. Use it in validators. Calls "eof()" method internally. */ + void readEof(); + + /* + * Quit-functions aborts program with and : + * input/answer streams replace any result to FAIL. + */ + NORETURN void quit(TResult result, const char *msg); + /* + * Quit-functions aborts program with and : + * input/answer streams replace any result to FAIL. + */ + NORETURN void quitf(TResult result, const char *msg, ...); + + /* + * Quit-functions aborts program with and : + * input/answer streams replace any result to FAIL. + */ + void quitif(bool condition, TResult result, const char *msg, ...); + /* + * Quit-functions aborts program with and : + * input/answer streams replace any result to FAIL. + */ + NORETURN void quits(TResult result, std::string msg); + + /* + * Checks condition and aborts a program if condition is false. + * Returns _wa for ouf and _fail on any other streams. + */ +#ifdef __GNUC__ + __attribute__ ((format (printf, 3, 4))) +#endif + void ensuref(bool cond, const char *format, ...); + + void __testlib_ensure(bool cond, std::string message); + + void close(); + + const static int NO_INDEX = INT_MAX; + const static char OPEN_BRACKET = char(11); + const static char CLOSE_BRACKET = char(17); + + const static WORD LightGray = 0x07; + const static WORD LightRed = 0x0c; + const static WORD LightCyan = 0x0b; + const static WORD LightGreen = 0x0a; + const static WORD LightYellow = 0x0e; + + static void textColor(WORD color); + + static void quitscr(WORD color, const char *msg); + + static void quitscrS(WORD color, std::string msg); + + void xmlSafeWrite(std::FILE *file, const char *msg); + + /* Skips UTF-8 Byte Order Mark. */ + void skipBom(); + +private: + InStream(const InStream &); + + InStream &operator=(const InStream &); +}; + +InStream inf; +InStream ouf; +InStream ans; +bool appesMode; +std::string appesModeEncoding = "windows-1251"; +std::string resultName; +std::string checkerName = "untitled checker"; +random_t rnd; +TTestlibMode testlibMode = _unknown; +double __testlib_points = std::numeric_limits::infinity(); + +const size_t VALIDATOR_MAX_VARIABLE_COUNT = 255; + +struct ValidatorBoundsHit { + static const double EPS; + bool minHit; + bool maxHit; + + ValidatorBoundsHit(bool minHit = false, bool maxHit = false) : minHit(minHit), maxHit(maxHit) { + }; + + ValidatorBoundsHit merge(const ValidatorBoundsHit &validatorBoundsHit, bool ignoreMinBound, bool ignoreMaxBound) { + return ValidatorBoundsHit( + __testlib_max(minHit, validatorBoundsHit.minHit) || ignoreMinBound, + __testlib_max(maxHit, validatorBoundsHit.maxHit) || ignoreMaxBound + ); + } +}; + +struct ConstantBound { + std::string value; + bool broken; + + template + void adjust(T t) { + std::string t_string = std::to_string(t); + if (t_string.length() >= 32) { + broken = true; + value = ""; + } else { + if (!broken && value.empty()) + value = t_string; + if (!broken && value != t_string) { + broken = true; + value = ""; + } + } + } + + bool has_value() { + return !value.empty() && !broken && value.length() < 32; + } +}; + +struct ConstantBounds { + ConstantBound lowerBound; + ConstantBound upperBound; +}; + +const double ValidatorBoundsHit::EPS = 1E-12; + +class Validator { +private: + const static std::string TEST_MARKUP_HEADER; + const static std::string TEST_CASE_OPEN_TAG; + const static std::string TEST_CASE_CLOSE_TAG; + + bool _initialized; + std::string _testset; + std::string _group; + + std::string _testOverviewLogFileName; + std::string _testMarkupFileName; + int _testCase = -1; + std::string _testCaseFileName; + + std::map _boundsHitByVariableName; + std::map _constantBoundsByVariableName; + std::set _features; + std::set _hitFeatures; + std::set _variables; + + bool isVariableNameBoundsAnalyzable(const std::string &variableName) { + for (size_t i = 0; i < variableName.length(); i++) + if ((variableName[i] >= '0' && variableName[i] <= '9') || variableName[i] < ' ') + return false; + return true; + } + + bool isFeatureNameAnalyzable(const std::string &featureName) { + for (size_t i = 0; i < featureName.length(); i++) + if (featureName[i] < ' ') + return false; + return true; + } + +public: + Validator() : _initialized(false), _testset("tests"), _group() { + } + + void initialize() { + _initialized = true; + } + + std::string testset() const { + if (!_initialized) + __testlib_fail("Validator should be initialized with registerValidation(argc, argv) instead of registerValidation() to support validator.testset()"); + return _testset; + } + + std::string group() const { + if (!_initialized) + __testlib_fail("Validator should be initialized with registerValidation(argc, argv) instead of registerValidation() to support validator.group()"); + return _group; + } + + std::string testOverviewLogFileName() const { + return _testOverviewLogFileName; + } + + std::string testMarkupFileName() const { + return _testMarkupFileName; + } + + int testCase() const { + return _testCase; + } + + std::string testCaseFileName() const { + return _testCaseFileName; + } + + void setTestset(const char *const testset) { + _testset = testset; + } + + void setGroup(const char *const group) { + _group = group; + } + + void setTestOverviewLogFileName(const char *const testOverviewLogFileName) { + _testOverviewLogFileName = testOverviewLogFileName; + } + + void setTestMarkupFileName(const char *const testMarkupFileName) { + _testMarkupFileName = testMarkupFileName; + } + + void setTestCase(int testCase) { + _testCase = testCase; + } + + void setTestCaseFileName(const char *const testCaseFileName) { + _testCaseFileName = testCaseFileName; + } + + std::string prepVariableName(const std::string &variableName) { + if (variableName.length() >= 2 && variableName != "~~") { + if (variableName[0] == '~' && variableName.back() != '~') + return variableName.substr(1); + if (variableName[0] != '~' && variableName.back() == '~') + return variableName.substr(0, variableName.length() - 1); + if (variableName[0] == '~' && variableName.back() == '~') + return variableName.substr(1, variableName.length() - 2); + } + return variableName; + } + + bool ignoreMinBound(const std::string &variableName) { + return variableName.length() >= 2 && variableName != "~~" && variableName[0] == '~'; + } + + bool ignoreMaxBound(const std::string &variableName) { + return variableName.length() >= 2 && variableName != "~~" && variableName.back() == '~'; + } + + void addBoundsHit(const std::string &variableName, ValidatorBoundsHit boundsHit) { + if (isVariableNameBoundsAnalyzable(variableName) + && _boundsHitByVariableName.size() < VALIDATOR_MAX_VARIABLE_COUNT) { + std::string preparedVariableName = prepVariableName(variableName); + _boundsHitByVariableName[preparedVariableName] = boundsHit.merge(_boundsHitByVariableName[preparedVariableName], + ignoreMinBound(variableName), ignoreMaxBound(variableName)); + } + } + + void addVariable(const std::string &variableName) { + if (isVariableNameBoundsAnalyzable(variableName) + && _variables.size() < VALIDATOR_MAX_VARIABLE_COUNT) { + std::string preparedVariableName = prepVariableName(variableName); + _variables.insert(preparedVariableName); + } + } + + std::string getVariablesLog() { + std::string result; + for (const std::string &variableName: _variables) + result += "variable \"" + variableName + "\"\n"; + return result; + } + + template + void adjustConstantBounds(const std::string &variableName, T lower, T upper) { + if (isVariableNameBoundsAnalyzable(variableName) + && _constantBoundsByVariableName.size() < VALIDATOR_MAX_VARIABLE_COUNT) { + std::string preparedVariableName = prepVariableName(variableName); + _constantBoundsByVariableName[preparedVariableName].lowerBound.adjust(lower); + _constantBoundsByVariableName[preparedVariableName].upperBound.adjust(upper); + } + } + + std::string getBoundsHitLog() { + std::string result; + for (std::map::iterator i = _boundsHitByVariableName.begin(); + i != _boundsHitByVariableName.end(); + i++) { + result += "\"" + i->first + "\":"; + if (i->second.minHit) + result += " min-value-hit"; + if (i->second.maxHit) + result += " max-value-hit"; + result += "\n"; + } + return result; + } + + std::string getConstantBoundsLog() { + std::string result; + for (std::map::iterator i = _constantBoundsByVariableName.begin(); + i != _constantBoundsByVariableName.end(); + i++) { + if (i->second.lowerBound.has_value() || i->second.upperBound.has_value()) { + result += "constant-bounds \"" + i->first + "\":"; + if (i->second.lowerBound.has_value()) + result += " " + i->second.lowerBound.value; + else + result += " ?"; + if (i->second.upperBound.has_value()) + result += " " + i->second.upperBound.value; + else + result += " ?"; + result += "\n"; + } + } + return result; + } + + std::string getFeaturesLog() { + std::string result; + for (std::set::iterator i = _features.begin(); + i != _features.end(); + i++) { + result += "feature \"" + *i + "\":"; + if (_hitFeatures.count(*i)) + result += " hit"; + result += "\n"; + } + return result; + } + + void writeTestOverviewLog() { + if (!_testOverviewLogFileName.empty()) { + std::string fileName(_testOverviewLogFileName); + _testOverviewLogFileName = ""; + + FILE* f; + bool standard_file = false; + if (fileName == "stdout") + f = stdout, standard_file = true; + else if (fileName == "stderr") + f = stderr, standard_file = true; + else { + f = testlib_fopen_(fileName.c_str(), "wb"); + if (NULL == f) + __testlib_fail("Validator::writeTestOverviewLog: can't write test overview log to (" + fileName + ")"); + } + fprintf(f, "%s%s%s%s", + getBoundsHitLog().c_str(), + getFeaturesLog().c_str(), + getConstantBoundsLog().c_str(), + getVariablesLog().c_str()); + std::fflush(f); + if (!standard_file) + if (std::fclose(f)) + __testlib_fail("Validator::writeTestOverviewLog: can't close test overview log file (" + fileName + ")"); + } + } + + void writeTestMarkup() { + if (!_testMarkupFileName.empty()) { + std::vector readChars = inf.getReadChars(); + if (!readChars.empty()) { + std::string markup(TEST_MARKUP_HEADER); + for (size_t i = 0; i < readChars.size(); i++) { + int c = readChars[i]; + if (i + 1 == readChars.size() && c == -1) + continue; + if (c <= 256) { + char cc = char(c); + if (cc == '\\' || cc == '!') + markup += '\\'; + markup += cc; + } else { + markup += TEST_CASE_OPEN_TAG; + markup += toString(c - 256); + markup += TEST_CASE_CLOSE_TAG; + } + } + FILE* f; + bool standard_file = false; + if (_testMarkupFileName == "stdout") + f = stdout, standard_file = true; + else if (_testMarkupFileName == "stderr") + f = stderr, standard_file = true; + else { + f = testlib_fopen_(_testMarkupFileName.c_str(), "wb"); + if (NULL == f) + __testlib_fail("Validator::writeTestMarkup: can't write test markup to (" + _testMarkupFileName + ")"); + } + std::fprintf(f, "%s", markup.c_str()); + std::fflush(f); + if (!standard_file) + if (std::fclose(f)) + __testlib_fail("Validator::writeTestMarkup: can't close test markup file (" + _testCaseFileName + ")"); + } + } + } + + void writeTestCase() { + if (_testCase > 0) { + std::vector readChars = inf.getReadChars(); + if (!readChars.empty()) { + std::string content, testCaseContent; + bool matchedTestCase = false; + for (size_t i = 0; i < readChars.size(); i++) { + int c = readChars[i]; + if (i + 1 == readChars.size() && c == -1) + continue; + if (c <= 256) + content += char(c); + else { + if (matchedTestCase) { + testCaseContent = content; + matchedTestCase = false; + } + content = ""; + int testCase = c - 256; + if (testCase == _testCase) + matchedTestCase = true; + } + } + if (matchedTestCase) + testCaseContent = content; + + if (!testCaseContent.empty()) { + FILE* f; + bool standard_file = false; + if (_testCaseFileName.empty() || _testCaseFileName == "stdout") + f = stdout, standard_file = true; + else if (_testCaseFileName == "stderr") + f = stderr, standard_file = true; + else { + f = testlib_fopen_(_testCaseFileName.c_str(), "wb"); + if (NULL == f) + __testlib_fail("Validator::writeTestCase: can't write test case to (" + _testCaseFileName + ")"); + } + std::fprintf(f, "%s", testCaseContent.c_str()); + std::fflush(f); + if (!standard_file) + if (std::fclose(f)) + __testlib_fail("Validator::writeTestCase: can't close test case file (" + _testCaseFileName + ")"); + } + } + } + } + + void addFeature(const std::string &feature) { + if (_features.count(feature)) + __testlib_fail("Feature " + feature + " registered twice."); + if (!isFeatureNameAnalyzable(feature)) + __testlib_fail("Feature name '" + feature + "' contains restricted characters."); + + _features.insert(feature); + } + + void feature(const std::string &feature) { + if (!isFeatureNameAnalyzable(feature)) + __testlib_fail("Feature name '" + feature + "' contains restricted characters."); + + if (!_features.count(feature)) + __testlib_fail("Feature " + feature + " didn't registered via addFeature(feature)."); + + _hitFeatures.insert(feature); + } +} validator; + +const std::string Validator::TEST_MARKUP_HEADER = "MU\xF3\x01"; +const std::string Validator::TEST_CASE_OPEN_TAG = "!c"; +const std::string Validator::TEST_CASE_CLOSE_TAG = ";"; + +struct TestlibFinalizeGuard { + static bool alive; + static bool registered; + + int quitCount, readEofCount; + + TestlibFinalizeGuard() : quitCount(0), readEofCount(0) { + // No operations. + } + + ~TestlibFinalizeGuard() { + bool _alive = alive; + alive = false; + + if (_alive) { + if (testlibMode == _checker && quitCount == 0) + __testlib_fail("Checker must end with quit or quitf call."); + + if (testlibMode == _validator && readEofCount == 0 && quitCount == 0) + __testlib_fail("Validator must end with readEof call."); + + /* opts */ + autoEnsureNoUnusedOpts(); + + if (!registered) + __testlib_fail("Call register-function in the first line of the main (registerTestlibCmd or other similar)"); + } + + if (__testlib_exitCode == 0) { + validator.writeTestOverviewLog(); + validator.writeTestMarkup(); + validator.writeTestCase(); + } + } + +private: + /* opts */ + void autoEnsureNoUnusedOpts(); +}; + +bool TestlibFinalizeGuard::alive = true; +bool TestlibFinalizeGuard::registered = false; +extern TestlibFinalizeGuard testlibFinalizeGuard; + +/* + * Call it to disable checks on finalization. + */ +void disableFinalizeGuard() { + TestlibFinalizeGuard::alive = false; +} + +/* Interactor streams. + */ +std::fstream tout; + +/* implementation + */ + +InStream::InStream() { + reader = NULL; + lastLine = -1; + opened = false; + name = ""; + mode = _input; + strict = false; + stdfile = false; + wordReserveSize = 4; + readManyIteration = NO_INDEX; + maxFileSize = 128 * 1024 * 1024; // 128MB. + maxTokenLength = 32 * 1024 * 1024; // 32MB. + maxMessageLength = 32000; +} + +InStream::InStream(const InStream &baseStream, std::string content) { + reader = new StringInputStreamReader(content); + lastLine = -1; + opened = true; + strict = baseStream.strict; + stdfile = false; + mode = baseStream.mode; + name = "based on " + baseStream.name; + readManyIteration = NO_INDEX; + maxFileSize = 128 * 1024 * 1024; // 128MB. + maxTokenLength = 32 * 1024 * 1024; // 32MB. + maxMessageLength = 32000; +} + +InStream::~InStream() { + if (NULL != reader) { + reader->close(); + delete reader; + reader = NULL; + } +} + +void InStream::setTestCase(int testCase) { + if (testlibMode != _validator || mode != _input || !stdfile || this != &inf) + __testlib_fail("InStream::setTestCase can be used only for inf in validator-mode." + " Actually, prefer setTestCase function instead of InStream member"); + reader->setTestCase(testCase); +} + +std::vector InStream::getReadChars() { + if (testlibMode != _validator || mode != _input || !stdfile || this != &inf) + __testlib_fail("InStream::getReadChars can be used only for inf in validator-mode."); + return reader == NULL ? std::vector() : reader->getReadChars(); +} + +void setTestCase(int testCase) { + static bool first_run = true; + static bool zero_based = false; + + if (first_run && testCase == 0) + zero_based = true; + + if (zero_based) + testCase++; + + __testlib_hasTestCase = true; + __testlib_testCase = testCase; + + if (testlibMode == _validator) + inf.setTestCase(testCase); + + first_run = false; +} + +#ifdef __GNUC__ +__attribute__((const)) +#endif +int resultExitCode(TResult r) { + if (r == _ok) + return OK_EXIT_CODE; + if (r == _wa) + return WA_EXIT_CODE; + if (r == _pe) + return PE_EXIT_CODE; + if (r == _fail) + return FAIL_EXIT_CODE; + if (r == _dirt) + return DIRT_EXIT_CODE; + if (r == _points) + return POINTS_EXIT_CODE; + if (r == _unexpected_eof) +#ifdef ENABLE_UNEXPECTED_EOF + return UNEXPECTED_EOF_EXIT_CODE; +#else + return PE_EXIT_CODE; +#endif + if (r >= _partially) + return PC_BASE_EXIT_CODE + (r - _partially); + return FAIL_EXIT_CODE; +} + +void InStream::textColor( +#if !(defined(ON_WINDOWS) && (!defined(_MSC_VER) || _MSC_VER > 1400)) && defined(__GNUC__) + __attribute__((unused)) +#endif + WORD color +) { +#if defined(ON_WINDOWS) && (!defined(_MSC_VER) || _MSC_VER > 1400) + HANDLE handle = GetStdHandle(STD_OUTPUT_HANDLE); + SetConsoleTextAttribute(handle, color); +#endif +#if !defined(ON_WINDOWS) && defined(__GNUC__) + if (isatty(2)) + { + switch (color) + { + case LightRed: + fprintf(stderr, "\033[1;31m"); + break; + case LightCyan: + fprintf(stderr, "\033[1;36m"); + break; + case LightGreen: + fprintf(stderr, "\033[1;32m"); + break; + case LightYellow: + fprintf(stderr, "\033[1;33m"); + break; + case LightGray: + default: + fprintf(stderr, "\033[0m"); + } + } +#endif +} + +#ifdef TESTLIB_THROW_EXIT_EXCEPTION_INSTEAD_OF_EXIT +class exit_exception: public std::exception { +private: + int exitCode; +public: + exit_exception(int exitCode): exitCode(exitCode) {} + int getExitCode() { return exitCode; } +}; +#endif + +NORETURN void halt(int exitCode) { +#ifdef FOOTER + InStream::textColor(InStream::LightGray); + std::fprintf(stderr, "Checker: \"%s\"\n", checkerName.c_str()); + std::fprintf(stderr, "Exit code: %d\n", exitCode); + InStream::textColor(InStream::LightGray); +#endif + __testlib_exitCode = exitCode; +#ifdef TESTLIB_THROW_EXIT_EXCEPTION_INSTEAD_OF_EXIT + throw exit_exception(exitCode); +#endif + std::exit(exitCode); +} + +static bool __testlib_shouldCheckDirt(TResult result) { + return result == _ok || result == _points || result >= _partially; +} + +static std::string __testlib_appendMessage(const std::string &message, const std::string &extra) { + int openPos = -1, closePos = -1; + for (size_t i = 0; i < message.length(); i++) { + if (message[i] == InStream::OPEN_BRACKET) { + if (openPos == -1) + openPos = int(i); + else + openPos = INT_MAX; + } + if (message[i] == InStream::CLOSE_BRACKET) { + if (closePos == -1) + closePos = int(i); + else + closePos = INT_MAX; + } + } + if (openPos != -1 && openPos != INT_MAX + && closePos != -1 && closePos != INT_MAX + && openPos < closePos) { + size_t index = message.find(extra, openPos); + if (index == std::string::npos || int(index) >= closePos) { + std::string result(message); + result.insert(closePos, ", " + extra); + return result; + } + return message; + } + + return message + " " + InStream::OPEN_BRACKET + extra + InStream::CLOSE_BRACKET; +} + +static std::string __testlib_toPrintableMessage(const std::string &message) { + int openPos = -1, closePos = -1; + for (size_t i = 0; i < message.length(); i++) { + if (message[i] == InStream::OPEN_BRACKET) { + if (openPos == -1) + openPos = int(i); + else + openPos = INT_MAX; + } + if (message[i] == InStream::CLOSE_BRACKET) { + if (closePos == -1) + closePos = int(i); + else + closePos = INT_MAX; + } + } + if (openPos != -1 && openPos != INT_MAX + && closePos != -1 && closePos != INT_MAX + && openPos < closePos) { + std::string result(message); + result[openPos] = '('; + result[closePos] = ')'; + return result; + } + + return message; +} + +NORETURN void InStream::quit(TResult result, const char *msg) { + if (TestlibFinalizeGuard::alive) + testlibFinalizeGuard.quitCount++; + + std::string message(msg); + message = trim(message); + + if (__testlib_hasTestCase) { + if (result != _ok) + message = __testlib_appendMessage(message, "test case " + vtos(__testlib_testCase)); + else { + if (__testlib_testCase == 1) + message = __testlib_appendMessage(message, vtos(__testlib_testCase) + " test case"); + else + message = __testlib_appendMessage(message, vtos(__testlib_testCase) + " test cases"); + } + } + + // You can change maxMessageLength. + // Example: 'inf.maxMessageLength = 1024 * 1024;'. + if (message.length() > maxMessageLength) { + std::string warn = "message length exceeds " + vtos(maxMessageLength) + + ", the message is truncated: "; + message = warn + message.substr(0, maxMessageLength - warn.length()); + } + +#ifndef ENABLE_UNEXPECTED_EOF + if (result == _unexpected_eof) + result = _pe; +#endif + + if (testlibMode == _scorer && result != _fail) + quits(_fail, "Scorer should return points only. Don't use a quit function."); + + if (mode != _output && result != _fail) { + if (mode == _input && testlibMode == _validator && lastLine != -1) + quits(_fail, __testlib_appendMessage(__testlib_appendMessage(message, name), "line " + vtos(lastLine))); + else + quits(_fail, __testlib_appendMessage(message, name)); + } + + std::FILE *resultFile; + std::string errorName; + + if (__testlib_shouldCheckDirt(result)) { + if (testlibMode != _interactor && !ouf.seekEof()) + quit(_dirt, "Extra information in the output file"); + } + + int pctype = result - _partially; + bool isPartial = false; + + switch (result) { + case _ok: + errorName = "ok "; + quitscrS(LightGreen, errorName); + break; + case _wa: + errorName = "wrong answer "; + quitscrS(LightRed, errorName); + break; + case _pe: + errorName = "wrong output format "; + quitscrS(LightRed, errorName); + break; + case _fail: + errorName = "FAIL "; + quitscrS(LightRed, errorName); + break; + case _dirt: + errorName = "wrong output format "; + quitscrS(LightCyan, errorName); + result = _pe; + break; + case _points: + errorName = "points "; + quitscrS(LightYellow, errorName); + break; + case _unexpected_eof: + errorName = "unexpected eof "; + quitscrS(LightCyan, errorName); + break; + default: + if (result >= _partially) { + errorName = testlib_format_("partially correct (%d) ", pctype); + isPartial = true; + quitscrS(LightYellow, errorName); + } else + quit(_fail, "What is the code ??? "); + } + + if (resultName != "") { + resultFile = testlib_fopen_(resultName.c_str(), "w"); + if (resultFile == NULL) { + resultName = ""; + quit(_fail, "Can not write to the result file"); + } + if (appesMode) { + std::fprintf(resultFile, "", appesModeEncoding.c_str()); + if (isPartial) + std::fprintf(resultFile, "", + outcomes[(int) _partially].c_str(), pctype); + else { + if (result != _points) + std::fprintf(resultFile, "", outcomes[(int) result].c_str()); + else { + if (__testlib_points == std::numeric_limits::infinity()) + quit(_fail, "Expected points, but infinity found"); + std::string stringPoints = removeDoubleTrailingZeroes(testlib_format_("%.10f", __testlib_points)); + std::fprintf(resultFile, "", + outcomes[(int) result].c_str(), stringPoints.c_str()); + } + } + xmlSafeWrite(resultFile, __testlib_toPrintableMessage(message).c_str()); + std::fprintf(resultFile, "\n"); + } else + std::fprintf(resultFile, "%s", __testlib_toPrintableMessage(message).c_str()); + if (NULL == resultFile || fclose(resultFile) != 0) { + resultName = ""; + quit(_fail, "Can not write to the result file"); + } + } + + quitscr(LightGray, __testlib_toPrintableMessage(message).c_str()); + std::fprintf(stderr, "\n"); + + inf.close(); + ouf.close(); + ans.close(); + if (tout.is_open()) + tout.close(); + + textColor(LightGray); + + if (resultName != "") + std::fprintf(stderr, "See file to check exit message\n"); + + halt(resultExitCode(result)); +} + +#ifdef __GNUC__ +__attribute__ ((format (printf, 3, 4))) +#endif +NORETURN void InStream::quitf(TResult result, const char *msg, ...) { + FMT_TO_RESULT(msg, msg, message); + InStream::quit(result, message.c_str()); +} + +#ifdef __GNUC__ +__attribute__ ((format (printf, 4, 5))) +#endif +void InStream::quitif(bool condition, TResult result, const char *msg, ...) { + if (condition) { + FMT_TO_RESULT(msg, msg, message); + InStream::quit(result, message.c_str()); + } +} + +NORETURN void InStream::quits(TResult result, std::string msg) { + InStream::quit(result, msg.c_str()); +} + +void InStream::xmlSafeWrite(std::FILE *file, const char *msg) { + size_t lmsg = strlen(msg); + for (size_t i = 0; i < lmsg; i++) { + if (msg[i] == '&') { + std::fprintf(file, "%s", "&"); + continue; + } + if (msg[i] == '<') { + std::fprintf(file, "%s", "<"); + continue; + } + if (msg[i] == '>') { + std::fprintf(file, "%s", ">"); + continue; + } + if (msg[i] == '"') { + std::fprintf(file, "%s", """); + continue; + } + if (0 <= msg[i] && msg[i] <= 31) { + std::fprintf(file, "%c", '.'); + continue; + } + std::fprintf(file, "%c", msg[i]); + } +} + +void InStream::quitscrS(WORD color, std::string msg) { + quitscr(color, msg.c_str()); +} + +void InStream::quitscr(WORD color, const char *msg) { + if (resultName == "") { + textColor(color); + std::fprintf(stderr, "%s", msg); + textColor(LightGray); + } +} + +void InStream::reset(std::FILE *file) { + if (opened && stdfile) + quit(_fail, "Can't reset standard handle"); + + if (opened) + close(); + + if (!stdfile && NULL == file) + if (NULL == (file = testlib_fopen_(name.c_str(), "rb"))) { + if (mode == _output) + quits(_pe, std::string("Output file not found: \"") + name + "\""); + + if (mode == _answer) + quits(_fail, std::string("Answer file not found: \"") + name + "\""); + } + + if (NULL != file) { + opened = true; + __testlib_set_binary(file); + + if (stdfile) + reader = new FileInputStreamReader(file, name); + else + reader = new BufferedFileInputStreamReader(file, name); + } else { + opened = false; + reader = NULL; + } +} + +void InStream::init(std::string fileName, TMode mode) { + opened = false; + name = fileName; + stdfile = false; + this->mode = mode; + + std::ifstream stream; + stream.open(fileName.c_str(), std::ios::in); + if (stream.is_open()) { + std::streampos start = stream.tellg(); + stream.seekg(0, std::ios::end); + std::streampos end = stream.tellg(); + size_t fileSize = size_t(end - start); + stream.close(); + + // You can change maxFileSize. + // Example: 'inf.maxFileSize = 256 * 1024 * 1024;'. + if (fileSize > maxFileSize) + quitf(_pe, "File size exceeds %d bytes, size is %d", int(maxFileSize), int(fileSize)); + } + + reset(); +} + +void InStream::init(std::FILE *f, TMode mode) { + opened = false; + name = "untitled"; + this->mode = mode; + + if (f == stdin) + name = "stdin", stdfile = true; + if (f == stdout) + name = "stdout", stdfile = true; + if (f == stderr) + name = "stderr", stdfile = true; + + reset(f); +} + +void InStream::skipBom() { + const std::string utf8Bom = "\xEF\xBB\xBF"; + size_t index = 0; + while (index < utf8Bom.size() && curChar() == utf8Bom[index]) { + index++; + skipChar(); + } + if (index < utf8Bom.size()) { + while (index != 0) { + unreadChar(utf8Bom[index - 1]); + index--; + } + } +} + +char InStream::curChar() { + return char(reader->curChar()); +} + +char InStream::nextChar() { + return char(reader->nextChar()); +} + +char InStream::readChar() { + return nextChar(); +} + +char InStream::readChar(char c) { + lastLine = reader->getLine(); + char found = readChar(); + if (c != found) { + if (!isEoln(found)) + quit(_pe, ("Unexpected character '" + std::string(1, found) + "', but '" + std::string(1, c) + + "' expected").c_str()); + else + quit(_pe, ("Unexpected character " + ("#" + vtos(int(found))) + ", but '" + std::string(1, c) + + "' expected").c_str()); + } + return found; +} + +char InStream::readSpace() { + return readChar(' '); +} + +void InStream::unreadChar(char c) { + reader->unreadChar(c); +} + +void InStream::skipChar() { + reader->skipChar(); +} + +void InStream::skipBlanks() { + while (isBlanks(reader->curChar())) + reader->skipChar(); +} + +std::string InStream::readWord() { + readWordTo(_tmpReadToken); + return _tmpReadToken; +} + +void InStream::readWordTo(std::string &result) { + if (!strict) + skipBlanks(); + + lastLine = reader->getLine(); + int cur = reader->nextChar(); + + if (cur == EOFC) + quit(_unexpected_eof, "Unexpected end of file - token expected"); + + if (isBlanks(cur)) + quit(_pe, "Unexpected white-space - token expected"); + + result.clear(); + + while (!(isBlanks(cur) || cur == EOFC)) { + result += char(cur); + + // You can change maxTokenLength. + // Example: 'inf.maxTokenLength = 128 * 1024 * 1024;'. + if (result.length() > maxTokenLength) + quitf(_pe, "Length of token exceeds %d, token is '%s...'", int(maxTokenLength), + __testlib_part(result).c_str()); + + cur = reader->nextChar(); + } + + reader->unreadChar(cur); + + if (result.length() == 0) + quit(_unexpected_eof, "Unexpected end of file or white-space - token expected"); +} + +std::string InStream::readToken() { + return readWord(); +} + +void InStream::readTokenTo(std::string &result) { + readWordTo(result); +} + +#ifdef __GNUC__ +__attribute__((const)) +#endif +static std::string __testlib_part(const std::string &s) { + std::string t; + for (size_t i = 0; i < s.length(); i++) + if (s[i] != '\0') + t += s[i]; + else + t += '~'; + if (t.length() <= 64) + return t; + else + return t.substr(0, 30) + "..." + t.substr(s.length() - 31, 31); +} + +#define __testlib_readMany(readMany, readOne, typeName, space) \ + if (size < 0) \ + quit(_fail, #readMany ": size should be non-negative."); \ + if (size > 100000000) \ + quit(_fail, #readMany ": size should be at most 100000000."); \ + \ + std::vector result(size); \ + readManyIteration = indexBase; \ + \ + for (int i = 0; i < size; i++) \ + { \ + result[i] = readOne; \ + readManyIteration++; \ + if (strict && space && i + 1 < size) \ + readSpace(); \ + } \ + \ + readManyIteration = NO_INDEX; \ + return result; \ + + +std::string InStream::readWord(const pattern &p, const std::string &variableName) { + readWordTo(_tmpReadToken); + if (!p.matches(_tmpReadToken)) { + if (readManyIteration == NO_INDEX) { + if (variableName.empty()) + quit(_wa, + ("Token \"" + __testlib_part(_tmpReadToken) + "\" doesn't correspond to pattern \"" + p.src() + + "\"").c_str()); + else + quit(_wa, ("Token parameter [name=" + variableName + "] equals to \"" + __testlib_part(_tmpReadToken) + + "\", doesn't correspond to pattern \"" + p.src() + "\"").c_str()); + } else { + if (variableName.empty()) + quit(_wa, ("Token element [index=" + vtos(readManyIteration) + "] equals to \"" + + __testlib_part(_tmpReadToken) + "\" doesn't correspond to pattern \"" + p.src() + + "\"").c_str()); + else + quit(_wa, ("Token element " + variableName + "[" + vtos(readManyIteration) + "] equals to \"" + + __testlib_part(_tmpReadToken) + "\", doesn't correspond to pattern \"" + p.src() + + "\"").c_str()); + } + } + if (strict && !variableName.empty()) + validator.addVariable(variableName); + return _tmpReadToken; +} + +std::vector +InStream::readWords(int size, const pattern &p, const std::string &variablesName, int indexBase) { + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readWords, readWord(p, variablesName), std::string, true); +} + +std::vector InStream::readWords(int size, int indexBase) { + __testlib_readMany(readWords, readWord(), std::string, true); +} + +std::string InStream::readWord(const std::string &ptrn, const std::string &variableName) { + return readWord(pattern(ptrn), variableName); +} + +std::vector +InStream::readWords(int size, const std::string &ptrn, const std::string &variablesName, int indexBase) { + pattern p(ptrn); + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readWords, readWord(p, variablesName), std::string, true); +} + +std::string InStream::readToken(const pattern &p, const std::string &variableName) { + return readWord(p, variableName); +} + +std::vector +InStream::readTokens(int size, const pattern &p, const std::string &variablesName, int indexBase) { + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readTokens, readToken(p, variablesName), std::string, true); +} + +std::vector InStream::readTokens(int size, int indexBase) { + __testlib_readMany(readTokens, readToken(), std::string, true); +} + +std::string InStream::readToken(const std::string &ptrn, const std::string &variableName) { + return readWord(ptrn, variableName); +} + +std::vector +InStream::readTokens(int size, const std::string &ptrn, const std::string &variablesName, int indexBase) { + pattern p(ptrn); + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readTokens, readWord(p, variablesName), std::string, true); +} + +void InStream::readWordTo(std::string &result, const pattern &p, const std::string &variableName) { + readWordTo(result); + if (!p.matches(result)) { + if (variableName.empty()) + quit(_wa, ("Token \"" + __testlib_part(result) + "\" doesn't correspond to pattern \"" + p.src() + + "\"").c_str()); + else + quit(_wa, ("Token parameter [name=" + variableName + "] equals to \"" + __testlib_part(result) + + "\", doesn't correspond to pattern \"" + p.src() + "\"").c_str()); + } + if (strict && !variableName.empty()) + validator.addVariable(variableName); +} + +void InStream::readWordTo(std::string &result, const std::string &ptrn, const std::string &variableName) { + return readWordTo(result, pattern(ptrn), variableName); +} + +void InStream::readTokenTo(std::string &result, const pattern &p, const std::string &variableName) { + return readWordTo(result, p, variableName); +} + +void InStream::readTokenTo(std::string &result, const std::string &ptrn, const std::string &variableName) { + return readWordTo(result, ptrn, variableName); +} + +#ifdef __GNUC__ +__attribute__((pure)) +#endif +static inline bool equals(long long integer, const char *s) { + if (integer == LLONG_MIN) + return strcmp(s, "-9223372036854775808") == 0; + + if (integer == 0LL) + return strcmp(s, "0") == 0; + + size_t length = strlen(s); + + if (length == 0) + return false; + + if (integer < 0 && s[0] != '-') + return false; + + if (integer < 0) + s++, length--, integer = -integer; + + if (length == 0) + return false; + + while (integer > 0) { + int digit = int(integer % 10); + + if (s[length - 1] != '0' + digit) + return false; + + length--; + integer /= 10; + } + + return length == 0; +} + +#ifdef __GNUC__ +__attribute__((pure)) +#endif +static inline bool equals(unsigned long long integer, const char *s) { + if (integer == ULLONG_MAX) + return strcmp(s, "18446744073709551615") == 0; + + if (integer == 0ULL) + return strcmp(s, "0") == 0; + + size_t length = strlen(s); + + if (length == 0) + return false; + + while (integer > 0) { + int digit = int(integer % 10); + + if (s[length - 1] != '0' + digit) + return false; + + length--; + integer /= 10; + } + + return length == 0; +} + +static inline double stringToDouble(InStream &in, const char *buffer) { + double result; + + size_t length = strlen(buffer); + + int minusCount = 0; + int plusCount = 0; + int decimalPointCount = 0; + int digitCount = 0; + int eCount = 0; + + for (size_t i = 0; i < length; i++) { + if (('0' <= buffer[i] && buffer[i] <= '9') || buffer[i] == '.' + || buffer[i] == 'e' || buffer[i] == 'E' + || buffer[i] == '-' || buffer[i] == '+') { + if ('0' <= buffer[i] && buffer[i] <= '9') + digitCount++; + if (buffer[i] == 'e' || buffer[i] == 'E') + eCount++; + if (buffer[i] == '-') + minusCount++; + if (buffer[i] == '+') + plusCount++; + if (buffer[i] == '.') + decimalPointCount++; + } else + in.quit(_pe, ("Expected double, but \"" + __testlib_part(buffer) + "\" found").c_str()); + } + + // If for sure is not a number in standard notation or in e-notation. + if (digitCount == 0 || minusCount > 2 || plusCount > 2 || decimalPointCount > 1 || eCount > 1) + in.quit(_pe, ("Expected double, but \"" + __testlib_part(buffer) + "\" found").c_str()); + + char *suffix = new char[length + 1]; + std::memset(suffix, 0, length + 1); + int scanned; +#ifdef _MSC_VER + scanned = sscanf_s(buffer, "%lf%s", &result, suffix, (unsigned int)(length + 1)); +#else + scanned = std::sscanf(buffer, "%lf%s", &result, suffix); +#endif + bool empty = strlen(suffix) == 0; + delete[] suffix; + + if (scanned == 1 || (scanned == 2 && empty)) { + if (__testlib_isNaN(result)) + in.quit(_pe, ("Expected double, but \"" + __testlib_part(buffer) + "\" found").c_str()); + return result; + } else + in.quit(_pe, ("Expected double, but \"" + __testlib_part(buffer) + "\" found").c_str()); +} + +static inline double stringToDouble(InStream &in, const std::string& buffer) { + for (size_t i = 0; i < buffer.length(); i++) + if (buffer[i] == '\0') + in.quit(_pe, ("Expected double, but \"" + __testlib_part(buffer) + "\" found (it contains \\0)").c_str()); + return stringToDouble(in, buffer.c_str()); +} + +static inline double stringToStrictDouble(InStream &in, const char *buffer, + int minAfterPointDigitCount, int maxAfterPointDigitCount) { + if (minAfterPointDigitCount < 0) + in.quit(_fail, "stringToStrictDouble: minAfterPointDigitCount should be non-negative."); + + if (minAfterPointDigitCount > maxAfterPointDigitCount) + in.quit(_fail, + "stringToStrictDouble: minAfterPointDigitCount should be less or equal to maxAfterPointDigitCount."); + + double result; + + size_t length = strlen(buffer); + + if (length == 0 || length > 1000) + in.quit(_pe, ("Expected strict double, but \"" + __testlib_part(buffer) + "\" found").c_str()); + + if (buffer[0] != '-' && (buffer[0] < '0' || buffer[0] > '9')) + in.quit(_pe, ("Expected strict double, but \"" + __testlib_part(buffer) + "\" found").c_str()); + + int pointPos = -1; + for (size_t i = 1; i + 1 < length; i++) { + if (buffer[i] == '.') { + if (pointPos > -1) + in.quit(_pe, ("Expected strict double, but \"" + __testlib_part(buffer) + "\" found").c_str()); + pointPos = int(i); + } + if (buffer[i] != '.' && (buffer[i] < '0' || buffer[i] > '9')) + in.quit(_pe, ("Expected strict double, but \"" + __testlib_part(buffer) + "\" found").c_str()); + } + + if (buffer[length - 1] < '0' || buffer[length - 1] > '9') + in.quit(_pe, ("Expected strict double, but \"" + __testlib_part(buffer) + "\" found").c_str()); + + int afterDigitsCount = (pointPos == -1 ? 0 : int(length) - pointPos - 1); + if (afterDigitsCount < minAfterPointDigitCount || afterDigitsCount > maxAfterPointDigitCount) + in.quit(_pe, ("Expected strict double with number of digits after point in range [" + + vtos(minAfterPointDigitCount) + + "," + + vtos(maxAfterPointDigitCount) + + "], but \"" + __testlib_part(buffer) + "\" found").c_str() + ); + + int firstDigitPos = -1; + for (size_t i = 0; i < length; i++) + if (buffer[i] >= '0' && buffer[i] <= '9') { + firstDigitPos = int(i); + break; + } + + if (firstDigitPos > 1 || firstDigitPos == -1) + in.quit(_pe, ("Expected strict double, but \"" + __testlib_part(buffer) + "\" found").c_str()); + + if (buffer[firstDigitPos] == '0' && firstDigitPos + 1 < int(length) + && buffer[firstDigitPos + 1] >= '0' && buffer[firstDigitPos + 1] <= '9') + in.quit(_pe, ("Expected strict double, but \"" + __testlib_part(buffer) + "\" found").c_str()); + + char *suffix = new char[length + 1]; + std::memset(suffix, 0, length + 1); + int scanned; +#ifdef _MSC_VER + scanned = sscanf_s(buffer, "%lf%s", &result, suffix, (unsigned int)(length + 1)); +#else + scanned = std::sscanf(buffer, "%lf%s", &result, suffix); +#endif + bool empty = strlen(suffix) == 0; + delete[] suffix; + + if (scanned == 1 || (scanned == 2 && empty)) { + if (__testlib_isNaN(result) || __testlib_isInfinite(result)) + in.quit(_pe, ("Expected double, but \"" + __testlib_part(buffer) + "\" found").c_str()); + if (buffer[0] == '-' && result >= 0) + in.quit(_pe, ("Redundant minus in \"" + __testlib_part(buffer) + "\" found").c_str()); + return result; + } else + in.quit(_pe, ("Expected double, but \"" + __testlib_part(buffer) + "\" found").c_str()); +} + +static inline double stringToStrictDouble(InStream &in, const std::string& buffer, + int minAfterPointDigitCount, int maxAfterPointDigitCount) { + for (size_t i = 0; i < buffer.length(); i++) + if (buffer[i] == '\0') + in.quit(_pe, ("Expected double, but \"" + __testlib_part(buffer) + "\" found (it contains \\0)").c_str()); + return stringToStrictDouble(in, buffer.c_str(), minAfterPointDigitCount, maxAfterPointDigitCount); +} + +static inline long long stringToLongLong(InStream &in, const char *buffer) { + size_t length = strlen(buffer); + if (length == 0 || length > 20) + in.quit(_pe, ("Expected integer, but \"" + __testlib_part(buffer) + "\" found").c_str()); + + bool has_minus = (length > 1 && buffer[0] == '-'); + int zeroes = 0; + bool processingZeroes = true; + + for (int i = (has_minus ? 1 : 0); i < int(length); i++) { + if (buffer[i] == '0' && processingZeroes) + zeroes++; + else + processingZeroes = false; + + if (buffer[i] < '0' || buffer[i] > '9') + in.quit(_pe, ("Expected integer, but \"" + __testlib_part(buffer) + "\" found").c_str()); + } + + long long int result; + try { + result = std::stoll(buffer); + } catch (const std::exception&) { + in.quit(_pe, ("Expected integer, but \"" + __testlib_part(buffer) + "\" found").c_str()); + } catch (...) { + in.quit(_pe, ("Expected integer, but \"" + __testlib_part(buffer) + "\" found").c_str()); + } + + if ((zeroes > 0 && (result != 0 || has_minus)) || zeroes > 1) + in.quit(_pe, ("Expected integer, but \"" + __testlib_part(buffer) + "\" found").c_str()); + + return result; +} + +static inline long long stringToLongLong(InStream &in, const std::string& buffer) { + for (size_t i = 0; i < buffer.length(); i++) + if (buffer[i] == '\0') + in.quit(_pe, ("Expected integer, but \"" + __testlib_part(buffer) + "\" found (it contains \\0)").c_str()); + return stringToLongLong(in, buffer.c_str()); +} + +static inline unsigned long long stringToUnsignedLongLong(InStream &in, const char *buffer) { + size_t length = strlen(buffer); + + if (length == 0 || length > 20) + in.quit(_pe, ("Expected unsigned integer, but \"" + __testlib_part(buffer) + "\" found").c_str()); + if (length > 1 && buffer[0] == '0') + in.quit(_pe, ("Expected unsigned integer, but \"" + __testlib_part(buffer) + "\" found").c_str()); + + for (int i = 0; i < int(length); i++) { + if (buffer[i] < '0' || buffer[i] > '9') + in.quit(_pe, ("Expected unsigned integer, but \"" + __testlib_part(buffer) + "\" found").c_str()); + } + + unsigned long long result; + try { + result = std::stoull(buffer); + } catch (const std::exception&) { + in.quit(_pe, ("Expected unsigned integer, but \"" + __testlib_part(buffer) + "\" found").c_str()); + } catch (...) { + in.quit(_pe, ("Expected unsigned integer, but \"" + __testlib_part(buffer) + "\" found").c_str()); + } + + return result; +} + +static inline long long stringToUnsignedLongLong(InStream &in, const std::string& buffer) { + for (size_t i = 0; i < buffer.length(); i++) + if (buffer[i] == '\0') + in.quit(_pe, ("Expected unsigned integer, but \"" + __testlib_part(buffer) + "\" found (it contains \\0)").c_str()); + return stringToUnsignedLongLong(in, buffer.c_str()); +} + +int InStream::readInteger() { + if (!strict && seekEof()) + quit(_unexpected_eof, "Unexpected end of file - int32 expected"); + + readWordTo(_tmpReadToken); + + long long value = stringToLongLong(*this, _tmpReadToken); + if (value < INT_MIN || value > INT_MAX) + quit(_pe, ("Expected int32, but \"" + __testlib_part(_tmpReadToken) + "\" found").c_str()); + + return int(value); +} + +long long InStream::readLong() { + if (!strict && seekEof()) + quit(_unexpected_eof, "Unexpected end of file - int64 expected"); + + readWordTo(_tmpReadToken); + + return stringToLongLong(*this, _tmpReadToken); +} + +unsigned long long InStream::readUnsignedLong() { + if (!strict && seekEof()) + quit(_unexpected_eof, "Unexpected end of file - int64 expected"); + + readWordTo(_tmpReadToken); + + return stringToUnsignedLongLong(*this, _tmpReadToken); +} + +long long InStream::readLong(long long minv, long long maxv, const std::string &variableName) { + long long result = readLong(); + + if (result < minv || result > maxv) { + if (readManyIteration == NO_INDEX) { + if (variableName.empty()) + quit(_wa, ("Integer " + vtos(result) + " violates the range [" + toHumanReadableString(minv) + ", " + toHumanReadableString(maxv) + + "]").c_str()); + else + quit(_wa, ("Integer parameter [name=" + std::string(variableName) + "] equals to " + vtos(result) + + ", violates the range [" + toHumanReadableString(minv) + ", " + toHumanReadableString(maxv) + "]").c_str()); + } else { + if (variableName.empty()) + quit(_wa, ("Integer element [index=" + vtos(readManyIteration) + "] equals to " + vtos(result) + + ", violates the range [" + toHumanReadableString(minv) + ", " + toHumanReadableString(maxv) + "]").c_str()); + else + quit(_wa, + ("Integer element " + std::string(variableName) + "[" + vtos(readManyIteration) + "] equals to " + + vtos(result) + ", violates the range [" + toHumanReadableString(minv) + ", " + toHumanReadableString(maxv) + "]").c_str()); + } + } + + if (strict && !variableName.empty()) { + validator.addBoundsHit(variableName, ValidatorBoundsHit(minv == result, maxv == result)); + validator.adjustConstantBounds(variableName, minv, maxv); + validator.addVariable(variableName); + } + + return result; +} + +std::vector +InStream::readLongs(int size, long long minv, long long maxv, const std::string &variablesName, int indexBase) { + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readLongs, readLong(minv, maxv, variablesName), long long, true) +} + +std::vector InStream::readLongs(int size, int indexBase) { + __testlib_readMany(readLongs, readLong(), long long, true) +} + +unsigned long long +InStream::readUnsignedLong(unsigned long long minv, unsigned long long maxv, const std::string &variableName) { + unsigned long long result = readUnsignedLong(); + + if (result < minv || result > maxv) { + if (readManyIteration == NO_INDEX) { + if (variableName.empty()) + quit(_wa, + ("Unsigned integer " + vtos(result) + " violates the range [" + toHumanReadableString(minv) + ", " + toHumanReadableString(maxv) + + "]").c_str()); + else + quit(_wa, + ("Unsigned integer parameter [name=" + std::string(variableName) + "] equals to " + vtos(result) + + ", violates the range [" + toHumanReadableString(minv) + ", " + toHumanReadableString(maxv) + "]").c_str()); + } else { + if (variableName.empty()) + quit(_wa, + ("Unsigned integer element [index=" + vtos(readManyIteration) + "] equals to " + vtos(result) + + ", violates the range [" + toHumanReadableString(minv) + ", " + toHumanReadableString(maxv) + "]").c_str()); + else + quit(_wa, ("Unsigned integer element " + std::string(variableName) + "[" + vtos(readManyIteration) + + "] equals to " + vtos(result) + ", violates the range [" + toHumanReadableString(minv) + ", " + toHumanReadableString(maxv) + + "]").c_str()); + } + } + + if (strict && !variableName.empty()) { + validator.addBoundsHit(variableName, ValidatorBoundsHit(minv == result, maxv == result)); + validator.adjustConstantBounds(variableName, minv, maxv); + validator.addVariable(variableName); + } + + return result; +} + +std::vector InStream::readUnsignedLongs(int size, unsigned long long minv, unsigned long long maxv, + const std::string &variablesName, int indexBase) { + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readUnsignedLongs, readUnsignedLong(minv, maxv, variablesName), unsigned long long, true) +} + +std::vector InStream::readUnsignedLongs(int size, int indexBase) { + __testlib_readMany(readUnsignedLongs, readUnsignedLong(), unsigned long long, true) +} + +unsigned long long +InStream::readLong(unsigned long long minv, unsigned long long maxv, const std::string &variableName) { + return readUnsignedLong(minv, maxv, variableName); +} + +int InStream::readInt() { + return readInteger(); +} + +int InStream::readInt(int minv, int maxv, const std::string &variableName) { + int result = readInt(); + + if (result < minv || result > maxv) { + if (readManyIteration == NO_INDEX) { + if (variableName.empty()) + quit(_wa, ("Integer " + vtos(result) + " violates the range [" + toHumanReadableString(minv) + ", " + toHumanReadableString(maxv) + + "]").c_str()); + else + quit(_wa, ("Integer parameter [name=" + std::string(variableName) + "] equals to " + vtos(result) + + ", violates the range [" + toHumanReadableString(minv) + ", " + toHumanReadableString(maxv) + "]").c_str()); + } else { + if (variableName.empty()) + quit(_wa, ("Integer element [index=" + vtos(readManyIteration) + "] equals to " + vtos(result) + + ", violates the range [" + toHumanReadableString(minv) + ", " + toHumanReadableString(maxv) + "]").c_str()); + else + quit(_wa, + ("Integer element " + std::string(variableName) + "[" + vtos(readManyIteration) + "] equals to " + + vtos(result) + ", violates the range [" + toHumanReadableString(minv) + ", " + toHumanReadableString(maxv) + "]").c_str()); + } + } + + if (strict && !variableName.empty()) { + validator.addBoundsHit(variableName, ValidatorBoundsHit(minv == result, maxv == result)); + validator.adjustConstantBounds(variableName, minv, maxv); + validator.addVariable(variableName); + } + + return result; +} + +int InStream::readInteger(int minv, int maxv, const std::string &variableName) { + return readInt(minv, maxv, variableName); +} + +std::vector InStream::readInts(int size, int minv, int maxv, const std::string &variablesName, int indexBase) { + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readInts, readInt(minv, maxv, variablesName), int, true) +} + +std::vector InStream::readInts(int size, int indexBase) { + __testlib_readMany(readInts, readInt(), int, true) +} + +std::vector InStream::readIntegers(int size, int minv, int maxv, const std::string &variablesName, int indexBase) { + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readIntegers, readInt(minv, maxv, variablesName), int, true) +} + +std::vector InStream::readIntegers(int size, int indexBase) { + __testlib_readMany(readIntegers, readInt(), int, true) +} + +double InStream::readReal() { + if (!strict && seekEof()) + quit(_unexpected_eof, "Unexpected end of file - double expected"); + + return stringToDouble(*this, readWord()); +} + +double InStream::readDouble() { + return readReal(); +} + +double InStream::readReal(double minv, double maxv, const std::string &variableName) { + double result = readReal(); + + if (result < minv || result > maxv) { + if (readManyIteration == NO_INDEX) { + if (variableName.empty()) + quit(_wa, ("Double " + vtos(result) + " violates the range [" + vtos(minv) + ", " + vtos(maxv) + + "]").c_str()); + else + quit(_wa, ("Double parameter [name=" + std::string(variableName) + "] equals to " + vtos(result) + + ", violates the range [" + vtos(minv) + ", " + vtos(maxv) + "]").c_str()); + } else { + if (variableName.empty()) + quit(_wa, ("Double element [index=" + vtos(readManyIteration) + "] equals to " + vtos(result) + + ", violates the range [" + vtos(minv) + ", " + vtos(maxv) + "]").c_str()); + else + quit(_wa, + ("Double element " + std::string(variableName) + "[" + vtos(readManyIteration) + "] equals to " + + vtos(result) + ", violates the range [" + vtos(minv) + ", " + vtos(maxv) + "]").c_str()); + } + } + + if (strict && !variableName.empty()) { + validator.addBoundsHit(variableName, ValidatorBoundsHit( + doubleDelta(minv, result) < ValidatorBoundsHit::EPS, + doubleDelta(maxv, result) < ValidatorBoundsHit::EPS + )); + validator.adjustConstantBounds(variableName, minv, maxv); + validator.addVariable(variableName); + } + + return result; +} + +std::vector +InStream::readReals(int size, double minv, double maxv, const std::string &variablesName, int indexBase) { + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readReals, readReal(minv, maxv, variablesName), double, true) +} + +std::vector InStream::readReals(int size, int indexBase) { + __testlib_readMany(readReals, readReal(), double, true) +} + +double InStream::readDouble(double minv, double maxv, const std::string &variableName) { + return readReal(minv, maxv, variableName); +} + +std::vector +InStream::readDoubles(int size, double minv, double maxv, const std::string &variablesName, int indexBase) { + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readDoubles, readDouble(minv, maxv, variablesName), double, true) +} + +std::vector InStream::readDoubles(int size, int indexBase) { + __testlib_readMany(readDoubles, readDouble(), double, true) +} + +double InStream::readStrictReal(double minv, double maxv, + int minAfterPointDigitCount, int maxAfterPointDigitCount, + const std::string &variableName) { + if (!strict && seekEof()) + quit(_unexpected_eof, "Unexpected end of file - strict double expected"); + + double result = stringToStrictDouble(*this, readWord(), minAfterPointDigitCount, maxAfterPointDigitCount); + + if (result < minv || result > maxv) { + if (readManyIteration == NO_INDEX) { + if (variableName.empty()) + quit(_wa, ("Strict double " + vtos(result) + " violates the range [" + vtos(minv) + ", " + vtos(maxv) + + "]").c_str()); + else + quit(_wa, + ("Strict double parameter [name=" + std::string(variableName) + "] equals to " + vtos(result) + + ", violates the range [" + vtos(minv) + ", " + vtos(maxv) + "]").c_str()); + } else { + if (variableName.empty()) + quit(_wa, ("Strict double element [index=" + vtos(readManyIteration) + "] equals to " + vtos(result) + + ", violates the range [" + vtos(minv) + ", " + vtos(maxv) + "]").c_str()); + else + quit(_wa, ("Strict double element " + std::string(variableName) + "[" + vtos(readManyIteration) + + "] equals to " + vtos(result) + ", violates the range [" + vtos(minv) + ", " + vtos(maxv) + + "]").c_str()); + } + } + + if (strict && !variableName.empty()) { + validator.addBoundsHit(variableName, ValidatorBoundsHit( + doubleDelta(minv, result) < ValidatorBoundsHit::EPS, + doubleDelta(maxv, result) < ValidatorBoundsHit::EPS + )); + validator.adjustConstantBounds(variableName, minv, maxv); + validator.addVariable(variableName); + } + + return result; +} + +std::vector InStream::readStrictReals(int size, double minv, double maxv, + int minAfterPointDigitCount, int maxAfterPointDigitCount, + const std::string &variablesName, int indexBase) { + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readStrictReals, + readStrictReal(minv, maxv, minAfterPointDigitCount, maxAfterPointDigitCount, variablesName), + double, true) +} + +double InStream::readStrictDouble(double minv, double maxv, + int minAfterPointDigitCount, int maxAfterPointDigitCount, + const std::string &variableName) { + return readStrictReal(minv, maxv, + minAfterPointDigitCount, maxAfterPointDigitCount, + variableName); +} + +std::vector InStream::readStrictDoubles(int size, double minv, double maxv, + int minAfterPointDigitCount, int maxAfterPointDigitCount, + const std::string &variablesName, int indexBase) { + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readStrictDoubles, + readStrictDouble(minv, maxv, minAfterPointDigitCount, maxAfterPointDigitCount, variablesName), + double, true) +} + +bool InStream::eof() { + if (!strict && NULL == reader) + return true; + + return reader->eof(); +} + +bool InStream::seekEof() { + if (!strict && NULL == reader) + return true; + skipBlanks(); + return eof(); +} + +bool InStream::eoln() { + if (!strict && NULL == reader) + return true; + + int c = reader->nextChar(); + + if (!strict) { + if (c == EOFC) + return true; + + if (c == CR) { + c = reader->nextChar(); + + if (c != LF) { + reader->unreadChar(c); + reader->unreadChar(CR); + return false; + } else + return true; + } + + if (c == LF) + return true; + + reader->unreadChar(c); + return false; + } else { + bool returnCr = false; + +#if (defined(ON_WINDOWS) && !defined(FOR_LINUX)) || defined(FOR_WINDOWS) + if (c != CR) { + reader->unreadChar(c); + return false; + } else { + if (!returnCr) + returnCr = true; + c = reader->nextChar(); + } +#endif + if (c != LF) { + reader->unreadChar(c); + if (returnCr) + reader->unreadChar(CR); + return false; + } + + return true; + } +} + +void InStream::readEoln() { + lastLine = reader->getLine(); + if (!eoln()) + quit(_pe, "Expected EOLN"); +} + +void InStream::readEof() { + lastLine = reader->getLine(); + if (!eof()) + quit(_pe, "Expected EOF"); + + if (TestlibFinalizeGuard::alive && this == &inf) + testlibFinalizeGuard.readEofCount++; +} + +bool InStream::seekEoln() { + if (!strict && NULL == reader) + return true; + + int cur; + do { + cur = reader->nextChar(); + } while (cur == SPACE || cur == TAB); + + reader->unreadChar(cur); + return eoln(); +} + +void InStream::nextLine() { + readLine(); +} + +void InStream::readStringTo(std::string &result) { + if (NULL == reader) + quit(_pe, "Expected line"); + + result.clear(); + + for (;;) { + int cur = reader->curChar(); + + if (cur == LF || cur == EOFC) + break; + + if (cur == CR) { + cur = reader->nextChar(); + if (reader->curChar() == LF) { + reader->unreadChar(cur); + break; + } + } + + lastLine = reader->getLine(); + result += char(reader->nextChar()); + } + + if (strict) + readEoln(); + else + eoln(); +} + +std::string InStream::readString() { + readStringTo(_tmpReadToken); + return _tmpReadToken; +} + +std::vector InStream::readStrings(int size, int indexBase) { + __testlib_readMany(readStrings, readString(), std::string, false) +} + +void InStream::readStringTo(std::string &result, const pattern &p, const std::string &variableName) { + readStringTo(result); + if (!p.matches(result)) { + if (readManyIteration == NO_INDEX) { + if (variableName.empty()) + quit(_wa, ("Line \"" + __testlib_part(result) + "\" doesn't correspond to pattern \"" + p.src() + + "\"").c_str()); + else + quit(_wa, ("Line [name=" + variableName + "] equals to \"" + __testlib_part(result) + + "\", doesn't correspond to pattern \"" + p.src() + "\"").c_str()); + } else { + if (variableName.empty()) + quit(_wa, + ("Line element [index=" + vtos(readManyIteration) + "] equals to \"" + __testlib_part(result) + + "\" doesn't correspond to pattern \"" + p.src() + "\"").c_str()); + else + quit(_wa, + ("Line element " + std::string(variableName) + "[" + vtos(readManyIteration) + "] equals to \"" + + __testlib_part(result) + "\", doesn't correspond to pattern \"" + p.src() + "\"").c_str()); + } + } + if (strict && !variableName.empty()) + validator.addVariable(variableName); +} + +void InStream::readStringTo(std::string &result, const std::string &ptrn, const std::string &variableName) { + readStringTo(result, pattern(ptrn), variableName); +} + +std::string InStream::readString(const pattern &p, const std::string &variableName) { + readStringTo(_tmpReadToken, p, variableName); + return _tmpReadToken; +} + +std::vector +InStream::readStrings(int size, const pattern &p, const std::string &variablesName, int indexBase) { + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readStrings, readString(p, variablesName), std::string, false) +} + +std::string InStream::readString(const std::string &ptrn, const std::string &variableName) { + readStringTo(_tmpReadToken, ptrn, variableName); + return _tmpReadToken; +} + +std::vector +InStream::readStrings(int size, const std::string &ptrn, const std::string &variablesName, int indexBase) { + pattern p(ptrn); + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readStrings, readString(p, variablesName), std::string, false) +} + +void InStream::readLineTo(std::string &result) { + readStringTo(result); +} + +std::string InStream::readLine() { + return readString(); +} + +std::vector InStream::readLines(int size, int indexBase) { + __testlib_readMany(readLines, readString(), std::string, false) +} + +void InStream::readLineTo(std::string &result, const pattern &p, const std::string &variableName) { + readStringTo(result, p, variableName); +} + +void InStream::readLineTo(std::string &result, const std::string &ptrn, const std::string &variableName) { + readStringTo(result, ptrn, variableName); +} + +std::string InStream::readLine(const pattern &p, const std::string &variableName) { + return readString(p, variableName); +} + +std::vector +InStream::readLines(int size, const pattern &p, const std::string &variablesName, int indexBase) { + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readLines, readString(p, variablesName), std::string, false) +} + +std::string InStream::readLine(const std::string &ptrn, const std::string &variableName) { + return readString(ptrn, variableName); +} + +std::vector +InStream::readLines(int size, const std::string &ptrn, const std::string &variablesName, int indexBase) { + pattern p(ptrn); + if (strict && !variablesName.empty()) + validator.addVariable(variablesName); + __testlib_readMany(readLines, readString(p, variablesName), std::string, false) +} + +#ifdef __GNUC__ +__attribute__ ((format (printf, 3, 4))) +#endif +void InStream::ensuref(bool cond, const char *format, ...) { + if (!cond) { + FMT_TO_RESULT(format, format, message); + this->__testlib_ensure(cond, message); + } +} + +void InStream::__testlib_ensure(bool cond, std::string message) { + if (!cond) + this->quit(_wa, message.c_str()); +} + +void InStream::close() { + if (NULL != reader) { + reader->close(); + delete reader; + reader = NULL; + } + + opened = false; +} + +NORETURN void quit(TResult result, const std::string &msg) { + ouf.quit(result, msg.c_str()); +} + +NORETURN void quit(TResult result, const char *msg) { + ouf.quit(result, msg); +} + +double __testlib_preparePoints(double points_) { + volatile double points = points_; + if (__testlib_isNaN(points)) + quit(_fail, "Parameter 'points' can't be nan"); + if (__testlib_isInfinite(points)) + quit(_fail, "Parameter 'points' can't be infinite"); + if (points < -1E-8) + quit(_fail, "Parameter 'points' can't be negative"); + if (points <= 0.0) + points = +0.0; + if (points > 1E6 + 1E-8) + quit(_fail, "Parameter 'points' can't be greater than 1E6"); + if (points >= 1E6) + points = 1E6; + return points; +} + +NORETURN void __testlib_quitp(double points, const char *message) { + __testlib_points = __testlib_preparePoints(points); + std::string stringPoints = removeDoubleTrailingZeroes(testlib_format_("%.10f", __testlib_points)); + + std::string quitMessage; + if (NULL == message || 0 == strlen(message)) + quitMessage = stringPoints; + else + quitMessage = stringPoints + " " + message; + + quit(_points, quitMessage.c_str()); +} + +NORETURN void __testlib_quitp(int points, const char *message) { + __testlib_points = __testlib_preparePoints(points); + std::string stringPoints = testlib_format_("%d", points); + + std::string quitMessage; + if (NULL == message || 0 == strlen(message)) + quitMessage = stringPoints; + else + quitMessage = stringPoints + " " + message; + + quit(_points, quitMessage.c_str()); +} + +NORETURN void quitp(float points, const std::string &message = "") { + __testlib_quitp(double(points), message.c_str()); +} + +NORETURN void quitp(double points, const std::string &message = "") { + __testlib_quitp(points, message.c_str()); +} + +NORETURN void quitp(long double points, const std::string &message = "") { + __testlib_quitp(double(points), message.c_str()); +} + +NORETURN void quitp(int points, const std::string &message = "") { + __testlib_quitp(points, message.c_str()); +} + +NORETURN void quitpi(const std::string &points_info, const std::string &message = "") { + if (points_info.find(' ') != std::string::npos) + quit(_fail, "Parameter 'points_info' can't contain spaces"); + if (message.empty()) + quit(_points, ("points_info=" + points_info).c_str()); + else + quit(_points, ("points_info=" + points_info + " " + message).c_str()); +} + +template +#ifdef __GNUC__ +__attribute__ ((format (printf, 2, 3))) +#endif +NORETURN void quitp(F points, const char *format, ...) { + FMT_TO_RESULT(format, format, message); + quitp(points, message); +} + +#ifdef __GNUC__ +__attribute__ ((format (printf, 2, 3))) +#endif +NORETURN void quitf(TResult result, const char *format, ...) { + FMT_TO_RESULT(format, format, message); + quit(result, message); +} + +#ifdef __GNUC__ +__attribute__ ((format (printf, 3, 4))) +#endif +void quitif(bool condition, TResult result, const char *format, ...) { + if (condition) { + FMT_TO_RESULT(format, format, message); + quit(result, message); + } +} + +NORETURN void __testlib_help() { + InStream::textColor(InStream::LightCyan); + std::fprintf(stderr, "TESTLIB %s, https://github.com/MikeMirzayanov/testlib/ ", VERSION); + std::fprintf(stderr, "by Mike Mirzayanov, copyright(c) 2005-2020\n"); + std::fprintf(stderr, "Checker name: \"%s\"\n", checkerName.c_str()); + InStream::textColor(InStream::LightGray); + + std::fprintf(stderr, "\n"); + std::fprintf(stderr, "Latest features: \n"); + for (size_t i = 0; i < sizeof(latestFeatures) / sizeof(char *); i++) { + std::fprintf(stderr, "*) %s\n", latestFeatures[i]); + } + std::fprintf(stderr, "\n"); + + std::fprintf(stderr, "Program must be run with the following arguments: \n"); + std::fprintf(stderr, " [--testset testset] [--group group] [ [<-appes>]]\n\n"); + + __testlib_exitCode = FAIL_EXIT_CODE; + std::exit(FAIL_EXIT_CODE); +} + +static void __testlib_ensuresPreconditions() { + // testlib assumes: sizeof(int) = 4. + __TESTLIB_STATIC_ASSERT(sizeof(int) == 4); + + // testlib assumes: INT_MAX == 2147483647. + __TESTLIB_STATIC_ASSERT(INT_MAX == 2147483647); + + // testlib assumes: sizeof(long long) = 8. + __TESTLIB_STATIC_ASSERT(sizeof(long long) == 8); + + // testlib assumes: sizeof(double) = 8. + __TESTLIB_STATIC_ASSERT(sizeof(double) == 8); + + // testlib assumes: no -ffast-math. + if (!__testlib_isNaN(+__testlib_nan())) + quit(_fail, "Function __testlib_isNaN is not working correctly: possible reason is '-ffast-math'"); + if (!__testlib_isNaN(-__testlib_nan())) + quit(_fail, "Function __testlib_isNaN is not working correctly: possible reason is '-ffast-math'"); +} + +std::string __testlib_testset; + +std::string getTestset() { + return __testlib_testset; +} + +std::string __testlib_group; + +std::string getGroup() { + return __testlib_group; +} + +static void __testlib_set_testset_and_group(int argc, char* argv[]) { + for (int i = 1; i < argc; i++) { + if (!strcmp("--testset", argv[i])) { + if (i + 1 < argc && strlen(argv[i + 1]) > 0) + __testlib_testset = argv[++i]; + else + quit(_fail, std::string("Expected non-empty testset after --testset command line parameter")); + } else if (!strcmp("--group", argv[i])) { + if (i + 1 < argc) + __testlib_group = argv[++i]; + else + quit(_fail, std::string("Expected group after --group command line parameter")); + } + } +} + +void registerGen(int argc, char *argv[], int randomGeneratorVersion) { + if (randomGeneratorVersion < 0 || randomGeneratorVersion > 1) + quitf(_fail, "Random generator version is expected to be 0 or 1."); + random_t::version = randomGeneratorVersion; + + __testlib_ensuresPreconditions(); + TestlibFinalizeGuard::registered = true; + + testlibMode = _generator; + __testlib_set_binary(stdin); + rnd.setSeed(argc, argv); + +#if __cplusplus > 199711L || defined(_MSC_VER) + prepareOpts(argc, argv); +#endif +} + +#ifdef USE_RND_AS_BEFORE_087 +void registerGen(int argc, char* argv[]) +{ + registerGen(argc, argv, 0); +} +#else +#ifdef __GNUC__ +#if (__GNUC__ > 4) || ((__GNUC__ == 4) && (__GNUC_MINOR__ > 4)) +__attribute__ ((deprecated("Use registerGen(argc, argv, 0) or registerGen(argc, argv, 1)." +" The third parameter stands for the random generator version." +" If you are trying to compile old generator use macro -DUSE_RND_AS_BEFORE_087 or registerGen(argc, argv, 0)." +" Version 1 has been released on Spring, 2013. Use it to write new generators."))) +#else +__attribute__ ((deprecated)) +#endif +#endif +#ifdef _MSC_VER +__declspec(deprecated("Use registerGen(argc, argv, 0) or registerGen(argc, argv, 1)." + " The third parameter stands for the random generator version." + " If you are trying to compile old generator use macro -DUSE_RND_AS_BEFORE_087 or registerGen(argc, argv, 0)." + " Version 1 has been released on Spring, 2013. Use it to write new generators.")) +#endif +void registerGen(int argc, char *argv[]) { + std::fprintf(stderr, "Use registerGen(argc, argv, 0) or registerGen(argc, argv, 1)." + " The third parameter stands for the random generator version." + " If you are trying to compile old generator use macro -DUSE_RND_AS_BEFORE_087 or registerGen(argc, argv, 0)." + " Version 1 has been released on Spring, 2013. Use it to write new generators.\n\n"); + registerGen(argc, argv, 0); +} +#endif + +void setAppesModeEncoding(std::string appesModeEncoding) { + static const char* const ENCODINGS[] = {"ascii", "utf-7", "utf-8", "utf-16", "utf-16le", "utf-16be", "utf-32", "utf-32le", "utf-32be", "iso-8859-1", +"iso-8859-2", "iso-8859-3", "iso-8859-4", "iso-8859-5", "iso-8859-6", "iso-8859-7", "iso-8859-8", "iso-8859-9", "iso-8859-10", "iso-8859-11", +"iso-8859-13", "iso-8859-14", "iso-8859-15", "iso-8859-16", "windows-1250", "windows-1251", "windows-1252", "windows-1253", "windows-1254", "windows-1255", +"windows-1256", "windows-1257", "windows-1258", "gb2312", "gbk", "gb18030", "big5", "shift-jis", "euc-jp", "euc-kr", +"euc-cn", "euc-tw", "koi8-r", "koi8-u", "tis-620", "ibm437", "ibm850", "ibm852", "ibm855", "ibm857", +"ibm860", "ibm861", "ibm862", "ibm863", "ibm865", "ibm866", "ibm869", "macroman", "maccentraleurope", "maciceland", +"maccroatian", "macromania", "maccyrillic", "macukraine", "macgreek", "macturkish", "machebrew", "macarabic", "macthai", "hz-gb-2312", +"iso-2022-jp", "iso-2022-kr", "iso-2022-cn", "armscii-8", "tscii", "iscii", "viscii", "geostd8", "cp949", "cp874", +"cp1006", "cp775", "cp858", "cp737", "cp853", "cp856", "cp922", "cp1046", "cp1125", "cp1131", +"ptcp154", "koi8-t", "koi8-ru", "mulelao-1", "cp1133", "iso-ir-166", "tcvn", "iso-ir-14", "iso-ir-87", "iso-ir-159"}; + + appesModeEncoding = lowerCase(appesModeEncoding); + bool valid = false; + for (size_t i = 0; i < sizeof(ENCODINGS) / sizeof(ENCODINGS[0]); i++) + if (appesModeEncoding == ENCODINGS[i]) { + valid = true; + break; + } + if (!valid) + quit(_fail, "Unexpected encoding for setAppesModeEncoding(encoding)"); + ::appesModeEncoding = appesModeEncoding; +} + +void registerInteraction(int argc, char *argv[]) { + __testlib_ensuresPreconditions(); + __testlib_set_testset_and_group(argc, argv); + TestlibFinalizeGuard::registered = true; + + testlibMode = _interactor; + __testlib_set_binary(stdin); + + if (argc > 1 && !strcmp("--help", argv[1])) + __testlib_help(); + + if (argc < 3 || argc > 6) { + quit(_fail, std::string("Program must be run with the following arguments: ") + + std::string(" [ [ [<-appes>]]]") + + "\nUse \"--help\" to get help information"); + } + + if (argc <= 4) { + resultName = ""; + appesMode = false; + } + +#ifndef EJUDGE + if (argc == 5) { + resultName = argv[4]; + appesMode = false; + } + + if (argc == 6) { + if (strcmp("-APPES", argv[5]) && strcmp("-appes", argv[5])) { + quit(_fail, std::string("Program must be run with the following arguments: ") + + " [ [<-appes>]]"); + } else { + resultName = argv[4]; + appesMode = true; + } + } +#endif + + inf.init(argv[1], _input); + + tout.open(argv[2], std::ios_base::out); + if (tout.fail() || !tout.is_open()) + quit(_fail, std::string("Can not write to the test-output-file '") + argv[2] + std::string("'")); + + ouf.init(stdin, _output); + + if (argc >= 4) + ans.init(argv[3], _answer); + else + ans.name = "unopened answer stream"; +} + +void registerValidation() { + __testlib_ensuresPreconditions(); + TestlibFinalizeGuard::registered = true; + + testlibMode = _validator; + + __testlib_set_binary(stdin); + __testlib_set_binary(stdout); + __testlib_set_binary(stderr); + + inf.init(stdin, _input); + inf.strict = true; +} + +void registerValidation(int argc, char *argv[]) { + registerValidation(); + __testlib_set_testset_and_group(argc, argv); + + validator.initialize(); + TestlibFinalizeGuard::registered = true; + + std::string comment = "Validator must be run with the following arguments:" + " [--testset testset]" + " [--group group]" + " [--testOverviewLogFileName fileName]" + " [--testMarkupFileName fileName]" + " [--testCase testCase]" + " [--testCaseFileName fileName]" + ; + + for (int i = 1; i < argc; i++) { + if (!strcmp("--testset", argv[i])) { + if (i + 1 < argc && strlen(argv[i + 1]) > 0) + validator.setTestset(argv[++i]); + else + quit(_fail, comment); + } + if (!strcmp("--group", argv[i])) { + if (i + 1 < argc) + validator.setGroup(argv[++i]); + else + quit(_fail, comment); + } + if (!strcmp("--testOverviewLogFileName", argv[i])) { + if (i + 1 < argc) + validator.setTestOverviewLogFileName(argv[++i]); + else + quit(_fail, comment); + } + if (!strcmp("--testMarkupFileName", argv[i])) { + if (i + 1 < argc) + validator.setTestMarkupFileName(argv[++i]); + else + quit(_fail, comment); + } + if (!strcmp("--testCase", argv[i])) { + if (i + 1 < argc) { + long long testCase = stringToLongLong(inf, argv[++i]); + if (testCase < 1 || testCase >= __TESTLIB_MAX_TEST_CASE) + quit(_fail, testlib_format_("Argument testCase should be between 1 and %d, but ", __TESTLIB_MAX_TEST_CASE) + + toString(testCase) + " found"); + validator.setTestCase(int(testCase)); + } else + quit(_fail, comment); + } + if (!strcmp("--testCaseFileName", argv[i])) { + if (i + 1 < argc) { + validator.setTestCaseFileName(argv[++i]); + } else + quit(_fail, comment); + } + } +} + +void addFeature(const std::string &feature) { + if (testlibMode != _validator) + quit(_fail, "Features are supported in validators only."); + validator.addFeature(feature); +} + +void feature(const std::string &feature) { + if (testlibMode != _validator) + quit(_fail, "Features are supported in validators only."); + validator.feature(feature); +} + +class Checker { +private: + bool _initialized; + std::string _testset; + std::string _group; + +public: + Checker() : _initialized(false), _testset("tests"), _group() { + } + + void initialize() { + _initialized = true; + } + + std::string testset() const { + if (!_initialized) + __testlib_fail("Checker should be initialized with registerTestlibCmd(argc, argv) instead of registerTestlibCmd() to support checker.testset()"); + return _testset; + } + + std::string group() const { + if (!_initialized) + __testlib_fail("Checker should be initialized with registerTestlibCmd(argc, argv) instead of registerTestlibCmd() to support checker.group()"); + return _group; + } + + void setTestset(const char *const testset) { + _testset = testset; + } + + void setGroup(const char *const group) { + _group = group; + } +} checker; + +void registerTestlibCmd(int argc, char *argv[]) { + __testlib_ensuresPreconditions(); + __testlib_set_testset_and_group(argc, argv); + TestlibFinalizeGuard::registered = true; + + testlibMode = _checker; + __testlib_set_binary(stdin); + + std::vector args(1, argv[0]); + checker.initialize(); + + for (int i = 1; i < argc; i++) { + if (!strcmp("--testset", argv[i])) { + if (i + 1 < argc && strlen(argv[i + 1]) > 0) + checker.setTestset(argv[++i]); + else + quit(_fail, std::string("Expected testset after --testset command line parameter")); + } else if (!strcmp("--group", argv[i])) { + if (i + 1 < argc) + checker.setGroup(argv[++i]); + else + quit(_fail, std::string("Expected group after --group command line parameter")); + } else + args.push_back(argv[i]); + } + + argc = int(args.size()); + if (argc > 1 && "--help" == args[1]) + __testlib_help(); + + if (argc < 4 || argc > 6) { + quit(_fail, std::string("Program must be run with the following arguments: ") + + std::string("[--testset testset] [--group group] [ [<-appes>]]") + + "\nUse \"--help\" to get help information"); + } + + if (argc == 4) { + resultName = ""; + appesMode = false; + } + +#ifndef EJUDGE + if (argc == 5) { + resultName = args[4]; + appesMode = false; + } + + if (argc == 6) { + if ("-APPES" != args[5] && "-appes" != args[5]) { + quit(_fail, std::string("Program must be run with the following arguments: ") + + " [ [<-appes>]]"); + } else { + resultName = args[4]; + appesMode = true; + } + } +#endif + + inf.init(args[1], _input); + ouf.init(args[2], _output); + ouf.skipBom(); + ans.init(args[3], _answer); +} + +void registerTestlib(int argc, ...) { + if (argc < 3 || argc > 5) + quit(_fail, std::string("Program must be run with the following arguments: ") + + " [ [<-appes>]]"); + + char **argv = new char *[argc + 1]; + + va_list ap; + va_start(ap, argc); + argv[0] = NULL; + for (int i = 0; i < argc; i++) { + argv[i + 1] = va_arg(ap, char*); + } + va_end(ap); + + registerTestlibCmd(argc + 1, argv); + delete[] argv; +} + +static inline void __testlib_ensure(bool cond, const std::string &msg) { + if (!cond) + quit(_fail, msg.c_str()); +} + +#ifdef __GNUC__ +__attribute__((unused)) +#endif +static inline void __testlib_ensure(bool cond, const char *msg) { + if (!cond) + quit(_fail, msg); +} + +#define ensure(cond) __testlib_ensure((cond), "Condition failed: \"" #cond "\"") +#define STRINGIZE_DETAIL(x) (#x) +#define STRINGIZE(x) STRINGIZE_DETAIL((x)) +#define ensure_ext(cond) __testlib_ensure((cond), "Line " STRINGIZE(__LINE__) ": Condition failed: \"" #cond "\"") + +#ifdef __GNUC__ +__attribute__ ((format (printf, 2, 3))) +#endif +inline void ensuref(bool cond, const char *format, ...) { + if (!cond) { + FMT_TO_RESULT(format, format, message); + __testlib_ensure(cond, message); + } +} + +NORETURN static void __testlib_fail(const std::string &message) { + quitf(_fail, "%s", message.c_str()); +} + +#ifdef __GNUC__ +__attribute__ ((format (printf, 1, 2))) +#endif +void setName(const char *format, ...) { + FMT_TO_RESULT(format, format, name); + checkerName = name; +} + +/* + * Do not use random_shuffle, because it will produce different result + * for different C++ compilers. + * + * This implementation uses testlib random_t to produce random numbers, so + * it is stable. + */ +template +void shuffle(_RandomAccessIter __first, _RandomAccessIter __last) { + if (__first == __last) return; + for (_RandomAccessIter __i = __first + 1; __i != __last; ++__i) + std::iter_swap(__i, __first + rnd.next(int(__i - __first) + 1)); +} + + +template +#if defined(__GNUC__) && !defined(__clang__) +__attribute__ ((error("Don't use random_shuffle(), use shuffle() instead"))) +#endif +void random_shuffle(_RandomAccessIter, _RandomAccessIter) { + quitf(_fail, "Don't use random_shuffle(), use shuffle() instead"); +} + +#ifdef __GLIBC__ +# define RAND_THROW_STATEMENT throw() +#else +# define RAND_THROW_STATEMENT +#endif + +#if defined(__GNUC__) && !defined(__clang__) + +__attribute__ ((error("Don't use rand(), use rnd.next() instead"))) +#endif +#ifdef _MSC_VER +# pragma warning( disable : 4273 ) +#endif +int rand() RAND_THROW_STATEMENT +{ + quitf(_fail, "Don't use rand(), use rnd.next() instead"); + + /* This line never runs. */ + //throw "Don't use rand(), use rnd.next() instead"; +} + +#if defined(__GNUC__) && !defined(__clang__) + +__attribute__ ((error("Don't use srand(), you should use " +"'registerGen(argc, argv, 1);' to initialize generator seed " +"by hash code of the command line params. The third parameter " +"is randomGeneratorVersion (currently the latest is 1)."))) +#endif +#ifdef _MSC_VER +# pragma warning( disable : 4273 ) +#endif +void srand(unsigned int seed) RAND_THROW_STATEMENT +{ + quitf(_fail, "Don't use srand(), you should use " + "'registerGen(argc, argv, 1);' to initialize generator seed " + "by hash code of the command line params. The third parameter " + "is randomGeneratorVersion (currently the latest is 1) [ignored seed=%u].", seed); +} + +void startTest(int test) { + const std::string testFileName = vtos(test); + if (NULL == testlib_freopen_(testFileName.c_str(), "wt", stdout)) + __testlib_fail("Unable to write file '" + testFileName + "'"); +} + +#ifdef __GNUC__ +__attribute__((const)) +#endif +inline std::string compress(const std::string &s) { + return __testlib_part(s); +} + +#ifdef __GNUC__ +__attribute__((const)) +#endif +inline std::string englishEnding(int x) { + x %= 100; + if (x / 10 == 1) + return "th"; + if (x % 10 == 1) + return "st"; + if (x % 10 == 2) + return "nd"; + if (x % 10 == 3) + return "rd"; + return "th"; +} + +template +#ifdef __GNUC__ +__attribute__((const)) +#endif +std::string join(_ForwardIterator first, _ForwardIterator last, _Separator separator) { + std::stringstream ss; + bool repeated = false; + for (_ForwardIterator i = first; i != last; i++) { + if (repeated) + ss << separator; + else + repeated = true; + ss << *i; + } + return ss.str(); +} + +template +#ifdef __GNUC__ +__attribute__((const)) +#endif +std::string join(_ForwardIterator first, _ForwardIterator last) { + return join(first, last, ' '); +} + +template +#ifdef __GNUC__ +__attribute__((const)) +#endif +std::string join(const _Collection &collection, _Separator separator) { + return join(collection.begin(), collection.end(), separator); +} + +template +#ifdef __GNUC__ +__attribute__((const)) +#endif +std::string join(const _Collection &collection) { + return join(collection, ' '); +} + +/** + * Splits string s by character separator returning exactly k+1 items, + * where k is the number of separator occurrences. + */ +#ifdef __GNUC__ +__attribute__((const)) +#endif +std::vector split(const std::string &s, char separator) { + std::vector result; + std::string item; + for (size_t i = 0; i < s.length(); i++) + if (s[i] == separator) { + result.push_back(item); + item = ""; + } else + item += s[i]; + result.push_back(item); + return result; +} + +/** + * Splits string s by character separators returning exactly k+1 items, + * where k is the number of separator occurrences. + */ +#ifdef __GNUC__ +__attribute__((const)) +#endif +std::vector split(const std::string &s, const std::string &separators) { + if (separators.empty()) + return std::vector(1, s); + + std::vector isSeparator(256); + for (size_t i = 0; i < separators.size(); i++) + isSeparator[(unsigned char) (separators[i])] = true; + + std::vector result; + std::string item; + for (size_t i = 0; i < s.length(); i++) + if (isSeparator[(unsigned char) (s[i])]) { + result.push_back(item); + item = ""; + } else + item += s[i]; + result.push_back(item); + return result; +} + +/** + * Splits string s by character separator returning non-empty items. + */ +#ifdef __GNUC__ +__attribute__((const)) +#endif +std::vector tokenize(const std::string &s, char separator) { + std::vector result; + std::string item; + for (size_t i = 0; i < s.length(); i++) + if (s[i] == separator) { + if (!item.empty()) + result.push_back(item); + item = ""; + } else + item += s[i]; + if (!item.empty()) + result.push_back(item); + return result; +} + +/** + * Splits string s by character separators returning non-empty items. + */ +#ifdef __GNUC__ +__attribute__((const)) +#endif +std::vector tokenize(const std::string &s, const std::string &separators) { + if (separators.empty()) + return std::vector(1, s); + + std::vector isSeparator(256); + for (size_t i = 0; i < separators.size(); i++) + isSeparator[(unsigned char) (separators[i])] = true; + + std::vector result; + std::string item; + for (size_t i = 0; i < s.length(); i++) + if (isSeparator[(unsigned char) (s[i])]) { + if (!item.empty()) + result.push_back(item); + item = ""; + } else + item += s[i]; + + if (!item.empty()) + result.push_back(item); + + return result; +} + +NORETURN void __testlib_expectedButFound(TResult result, std::string expected, std::string found, const char *prepend) { + std::string message; + if (strlen(prepend) != 0) + message = testlib_format_("%s: expected '%s', but found '%s'", + compress(prepend).c_str(), compress(expected).c_str(), compress(found).c_str()); + else + message = testlib_format_("expected '%s', but found '%s'", + compress(expected).c_str(), compress(found).c_str()); + quit(result, message); +} + +NORETURN void __testlib_expectedButFound(TResult result, double expected, double found, const char *prepend) { + std::string expectedString = removeDoubleTrailingZeroes(testlib_format_("%.12f", expected)); + std::string foundString = removeDoubleTrailingZeroes(testlib_format_("%.12f", found)); + __testlib_expectedButFound(result, expectedString, foundString, prepend); +} + +template +#ifdef __GNUC__ +__attribute__ ((format (printf, 4, 5))) +#endif +NORETURN void expectedButFound(TResult result, T expected, T found, const char *prependFormat = "", ...) { + FMT_TO_RESULT(prependFormat, prependFormat, prepend); + std::string expectedString = vtos(expected); + std::string foundString = vtos(found); + __testlib_expectedButFound(result, expectedString, foundString, prepend.c_str()); +} + +template<> +#ifdef __GNUC__ +__attribute__ ((format (printf, 4, 5))) +#endif +NORETURN void +expectedButFound(TResult result, std::string expected, std::string found, const char *prependFormat, ...) { + FMT_TO_RESULT(prependFormat, prependFormat, prepend); + __testlib_expectedButFound(result, expected, found, prepend.c_str()); +} + +template<> +#ifdef __GNUC__ +__attribute__ ((format (printf, 4, 5))) +#endif +NORETURN void expectedButFound(TResult result, double expected, double found, const char *prependFormat, ...) { + FMT_TO_RESULT(prependFormat, prependFormat, prepend); + std::string expectedString = removeDoubleTrailingZeroes(testlib_format_("%.12f", expected)); + std::string foundString = removeDoubleTrailingZeroes(testlib_format_("%.12f", found)); + __testlib_expectedButFound(result, expectedString, foundString, prepend.c_str()); +} + +template<> +#ifdef __GNUC__ +__attribute__ ((format (printf, 4, 5))) +#endif +NORETURN void +expectedButFound(TResult result, const char *expected, const char *found, const char *prependFormat, + ...) { + FMT_TO_RESULT(prependFormat, prependFormat, prepend); + __testlib_expectedButFound(result, std::string(expected), std::string(found), prepend.c_str()); +} + +template<> +#ifdef __GNUC__ +__attribute__ ((format (printf, 4, 5))) +#endif +NORETURN void expectedButFound(TResult result, float expected, float found, const char *prependFormat, ...) { + FMT_TO_RESULT(prependFormat, prependFormat, prepend); + __testlib_expectedButFound(result, double(expected), double(found), prepend.c_str()); +} + +template<> +#ifdef __GNUC__ +__attribute__ ((format (printf, 4, 5))) +#endif +NORETURN void +expectedButFound(TResult result, long double expected, long double found, const char *prependFormat, ...) { + FMT_TO_RESULT(prependFormat, prependFormat, prepend); + __testlib_expectedButFound(result, double(expected), double(found), prepend.c_str()); +} + +#if __cplusplus > 199711L || defined(_MSC_VER) +template +struct is_iterable { + template + static char test(typename U::iterator *x); + + template + static long test(U *x); + + static const bool value = sizeof(test(0)) == 1; +}; + +template +struct __testlib_enable_if { +}; + +template +struct __testlib_enable_if { + typedef T type; +}; + +template +typename __testlib_enable_if::value, void>::type __testlib_print_one(const T &t) { + std::cout << t; +} + +template +typename __testlib_enable_if::value, void>::type __testlib_print_one(const T &t) { + bool first = true; + for (typename T::const_iterator i = t.begin(); i != t.end(); i++) { + if (first) + first = false; + else + std::cout << " "; + std::cout << *i; + } +} + +template<> +typename __testlib_enable_if::value, void>::type +__testlib_print_one(const std::string &t) { + std::cout << t; +} + +template +void __println_range(A begin, B end) { + bool first = true; + for (B i = B(begin); i != end; i++) { + if (first) + first = false; + else + std::cout << " "; + __testlib_print_one(*i); + } + std::cout << std::endl; +} + +template +struct is_iterator { + static T makeT(); + + typedef void *twoptrs[2]; + + static twoptrs &test(...); + + template + static typename R::iterator_category *test(R); + + template + static void *test(R *); + + static const bool value = sizeof(test(makeT())) == sizeof(void *); +}; + +template +struct is_iterator::value>::type> { + static const bool value = false; +}; + +template +typename __testlib_enable_if::value, void>::type println(const A &a, const B &b) { + __testlib_print_one(a); + std::cout << " "; + __testlib_print_one(b); + std::cout << std::endl; +} + +template +typename __testlib_enable_if::value, void>::type println(const A &a, const B &b) { + __println_range(a, b); +} + +template +void println(const A *a, const A *b) { + __println_range(a, b); +} + +template<> +void println(const char *a, const char *b) { + __testlib_print_one(a); + std::cout << " "; + __testlib_print_one(b); + std::cout << std::endl; +} + +template +void println(const T &x) { + __testlib_print_one(x); + std::cout << std::endl; +} + +template +void println(const A &a, const B &b, const C &c) { + __testlib_print_one(a); + std::cout << " "; + __testlib_print_one(b); + std::cout << " "; + __testlib_print_one(c); + std::cout << std::endl; +} + +template +void println(const A &a, const B &b, const C &c, const D &d) { + __testlib_print_one(a); + std::cout << " "; + __testlib_print_one(b); + std::cout << " "; + __testlib_print_one(c); + std::cout << " "; + __testlib_print_one(d); + std::cout << std::endl; +} + +template +void println(const A &a, const B &b, const C &c, const D &d, const E &e) { + __testlib_print_one(a); + std::cout << " "; + __testlib_print_one(b); + std::cout << " "; + __testlib_print_one(c); + std::cout << " "; + __testlib_print_one(d); + std::cout << " "; + __testlib_print_one(e); + std::cout << std::endl; +} + +template +void println(const A &a, const B &b, const C &c, const D &d, const E &e, const F &f) { + __testlib_print_one(a); + std::cout << " "; + __testlib_print_one(b); + std::cout << " "; + __testlib_print_one(c); + std::cout << " "; + __testlib_print_one(d); + std::cout << " "; + __testlib_print_one(e); + std::cout << " "; + __testlib_print_one(f); + std::cout << std::endl; +} + +template +void println(const A &a, const B &b, const C &c, const D &d, const E &e, const F &f, const G &g) { + __testlib_print_one(a); + std::cout << " "; + __testlib_print_one(b); + std::cout << " "; + __testlib_print_one(c); + std::cout << " "; + __testlib_print_one(d); + std::cout << " "; + __testlib_print_one(e); + std::cout << " "; + __testlib_print_one(f); + std::cout << " "; + __testlib_print_one(g); + std::cout << std::endl; +} + +/* opts */ + +/** + * A struct for a singular testlib opt, containing the raw string value, + * and a boolean value for marking whether the opt is used. + */ +struct TestlibOpt { + std::string value; + bool used; + + TestlibOpt() : value(), used(false) {} +}; + +/** + * Get the type of opt based on the number of `-` at the beginning and the + * _validity_ of the key name. + * + * A valid key name must start with an alphabetical character. + * + * Returns: 1 if s has one `-` at the beginning, that is, "-keyName". + * 2 if s has two `-` at the beginning, that is, "--keyName". + * 0 otherwise. That is, if s has no `-` at the beginning, or has more + * than 2 at the beginning ("---keyName", "----keyName", ...), or the + * keyName is invalid (the first character is not an alphabetical + * character). + */ +size_t getOptType(char *s) { + if (!s || strlen(s) <= 1) + return 0; + + if (s[0] == '-') { + if (isalpha(s[1])) + return 1; + else if (s[1] == '-') + return isalpha(s[2]) ? 2 : 0; + } + + return 0; +} + +/** + * Parse the opt at a given index, and put it into the opts maps. + * + * An opt can has the following form: + * 1) -keyName=value or --keyName=value (ex. -n=10 --test-count=20) + * 2) -keyName value or --keyName value (ex. -n 10 --test-count 20) + * 3) -kNumval or --kNumval (ex. -n10 --t20) + * 4) -boolProperty or --boolProperty (ex. -sorted --tree-only) + * + * Only the second form consumes 2 arguments. The other consumes only 1 + * argument. + * + * In the third form, the key is a single character, and after the key is the + * value. The value _should_ be a number. + * + * In the forth form, the value is true. + * + * Params: + * - argc and argv: the number of command line arguments and the command line + * arguments themselves. + * - index: the starting index of the opts. + * - opts: the map containing the resulting opt. + * + * Returns: the number of consumed arguments to parse the opt. + * 0 if there is no arguments to parse. + * + * Algorithm details: + * TODO. Please refer to the implementation to see how the code handles the 3rd and 4th forms separately. + */ +size_t parseOpt(size_t argc, char *argv[], size_t index, std::map &opts) { + if (index >= argc) + return 0; + + size_t type = getOptType(argv[index]), inc = 1; + if (type > 0) { + std::string key(argv[index] + type), val; + size_t sep = key.find('='); + if (sep != std::string::npos) { + val = key.substr(sep + 1); + key = key.substr(0, sep); + } else { + if (index + 1 < argc && getOptType(argv[index + 1]) == 0) { + val = argv[index + 1]; + inc = 2; + } else { + if (key.length() > 1 && isdigit(key[1])) { + val = key.substr(1); + key = key.substr(0, 1); + } else { + val = "true"; + } + } + } + opts[key].value = val; + } else { + return inc; + } + + return inc; +} + +/** + * Global list containing all the arguments in the order given in the command line. + */ +std::vector __testlib_argv; + +/** + * Global dictionary containing all the parsed opts. + */ +std::map __testlib_opts; + +/** + * Whether automatic no unused opts ensurement should be done. This flag will + * be turned on when `has_opt` or `opt(key, default_value)` is called. + * + * The automatic ensurement can be suppressed when + * __testlib_ensureNoUnusedOptsSuppressed is true. + */ +bool __testlib_ensureNoUnusedOptsFlag = false; + +/** + * Suppress no unused opts automatic ensurement. Can be set to true with + * `suppressEnsureNoUnusedOpts()`. + */ +bool __testlib_ensureNoUnusedOptsSuppressed = false; + +/** + * Parse command line arguments into opts. + * The results are stored into __testlib_argv and __testlib_opts. + */ +void prepareOpts(int argc, char *argv[]) { + if (argc <= 0) + __testlib_fail("Opts: expected argc>=0 but found " + toString(argc)); + size_t n = static_cast(argc); // NOLINT(hicpp-use-auto,modernize-use-auto) + __testlib_opts = std::map(); + for (size_t index = 1; index < n; index += parseOpt(n, argv, index, __testlib_opts)); + __testlib_argv = std::vector(n); + for (size_t index = 0; index < n; index++) + __testlib_argv[index] = argv[index]; +} + +/** + * An utility function to get the argument with a given index. This function + * also print a readable message when no arguments are found. + */ +std::string __testlib_indexToArgv(int index) { + if (index < 0 || index >= int(__testlib_argv.size())) + __testlib_fail("Opts: index '" + toString(index) + "' is out of range [0," + + toString(__testlib_argv.size()) + ")"); + return __testlib_argv[size_t(index)]; +} + +/** + * An utility function to get the opt with a given key . This function + * also print a readable message when no opts are found. + */ +std::string __testlib_keyToOpts(const std::string &key) { + auto it = __testlib_opts.find(key); + if (it == __testlib_opts.end()) + __testlib_fail("Opts: unknown key '" + compress(key) + "'"); + it->second.used = true; + return it->second.value; +} + +template +T optValueToIntegral(const std::string &s, bool nonnegative); + +long double optValueToLongDouble(const std::string &s); + +std::string parseExponentialOptValue(const std::string &s) { + size_t pos = std::string::npos; + for (size_t i = 0; i < s.length(); i++) + if (s[i] == 'e' || s[i] == 'E') { + if (pos != std::string::npos) + __testlib_fail("Opts: expected typical exponential notation but '" + compress(s) + "' found"); + pos = i; + } + if (pos == std::string::npos) + return s; + std::string e = s.substr(pos + 1); + if (!e.empty() && e[0] == '+') + e = e.substr(1); + if (e.empty()) + __testlib_fail("Opts: expected typical exponential notation but '" + compress(s) + "' found"); + if (e.length() > 20) + __testlib_fail("Opts: expected typical exponential notation but '" + compress(s) + "' found"); + int ne = optValueToIntegral(e, false); + std::string num = s.substr(0, pos); + if (num.length() > 20) + __testlib_fail("Opts: expected typical exponential notation but '" + compress(s) + "' found"); + if (!num.empty() && num[0] == '+') + num = num.substr(1); + optValueToLongDouble(num); + bool minus = false; + if (num[0] == '-') { + minus = true; + num = num.substr(1); + } + for (int i = 0; i < +ne; i++) { + size_t sep = num.find('.'); + if (sep == std::string::npos) + num += '0'; + else { + if (sep + 1 == num.length()) + num[sep] = '0'; + else + std::swap(num[sep], num[sep + 1]); + } + } + for (int i = 0; i < -ne; i++) { + size_t sep = num.find('.'); + if (sep == std::string::npos) + num.insert(num.begin() + int(num.length()) - 1, '.'); + else { + if (sep == 0) + num.insert(num.begin() + 1, '0'); + else + std::swap(num[sep - 1], num[sep]); + } + } + while (!num.empty() && num[0] == '0') + num = num.substr(1); + while (num.find('.') != std::string::npos && num.back() == '0') + num = num.substr(0, num.length() - 1); + if (!num.empty() && num.back() == '.') + num = num.substr(0, num.length() - 1); + if ((!num.empty() && num[0] == '.') || num.empty()) + num.insert(num.begin(), '0'); + return (minus ? "-" : "") + num; +} + +template +T optValueToIntegral(const std::string &s_, bool nonnegative) { + std::string s(parseExponentialOptValue(s_)); + if (s.empty()) + __testlib_fail("Opts: expected integer but '" + compress(s_) + "' found"); + T value = 0; + long double about = 0.0; + signed char sign = +1; + size_t pos = 0; + if (s[pos] == '-') { + if (nonnegative) + __testlib_fail("Opts: expected non-negative integer but '" + compress(s_) + "' found"); + sign = -1; + pos++; + } + for (size_t i = pos; i < s.length(); i++) { + if (s[i] < '0' || s[i] > '9') + __testlib_fail("Opts: expected integer but '" + compress(s_) + "' found"); + value = T(value * 10 + s[i] - '0'); + about = about * 10 + s[i] - '0'; + } + value *= sign; + about *= sign; + if (fabsl(value - about) > 0.1) + __testlib_fail("Opts: integer overflow: expected integer but '" + compress(s_) + "' found"); + return value; +} + +long double optValueToLongDouble(const std::string &s_) { + std::string s(parseExponentialOptValue(s_)); + if (s.empty()) + __testlib_fail("Opts: expected float number but '" + compress(s_) + "' found"); + long double value = 0.0; + signed char sign = +1; + size_t pos = 0; + if (s[pos] == '-') { + sign = -1; + pos++; + } + bool period = false; + long double mul = 1.0; + for (size_t i = pos; i < s.length(); i++) { + if (s[i] == '.') { + if (period) + __testlib_fail("Opts: expected float number but '" + compress(s_) + "' found"); + else { + period = true; + continue; + } + } + if (period) + mul *= 10.0; + if (s[i] < '0' || s[i] > '9') + __testlib_fail("Opts: expected float number but '" + compress(s_) + "' found"); + if (period) + value += (s[i] - '0') / mul; + else + value = value * 10 + s[i] - '0'; + } + value *= sign; + return value; +} + +/** + * Return true if there is an opt with a given key. + * + * By calling this function, automatic ensurement for no unused opts will be + * done when the program is finalized. Call suppressEnsureNoUnusedOpts() to + * turn it off. + */ +bool has_opt(const std::string &key) { + __testlib_ensureNoUnusedOptsFlag = true; + return __testlib_opts.count(key) != 0; +} + +/* About the following part for opt with 2 and 3 arguments. + * + * To parse the argv/opts correctly for a give type (integer, floating point or + * string), some meta programming must be done to determine the type of + * the type, and use the correct parsing function accordingly. + * + * The pseudo algorithm for determining the type of T and parse it accordingly + * is as follows: + * + * if (T is integral type) { + * if (T is unsigned) { + * parse the argv/opt as an **unsigned integer** of type T. + * } else { + * parse the argv/opt as an **signed integer** of type T. + * } else { + * if (T is floating point type) { + * parse the argv/opt as an **floating point** of type T. + * } else { + * // T should be std::string + * just the raw content of the argv/opts. + * } + * } + * + * To help with meta programming, some `opt` function with 2 or 3 arguments are + * defined. + * + * Opt with 3 arguments: T opt(true/false is_integral, true/false is_unsigned, index/key) + * + * + The first argument is for determining whether the type T is an integral + * type. That is, the result of std::is_integral() should be passed to + * this argument. When false, the type _should_ be either floating point or a + * std::string. + * + * + The second argument is for determining whether the signedness of the type + * T (if it is unsigned or signed). That is, the result of + * std::is_unsigned() should be passed to this argument. This argument can + * be ignored if the first one is false, because it only applies to integer. + * + * Opt with 2 arguments: T opt(true/false is_floating_point, index/key) + * + The first argument is for determining whether the type T is a floating + * point type. That is, the result of std::is_floating_point() should be + * passed to this argument. When false, the type _should_ be a std::string. + */ + +template +T opt(std::false_type is_floating_point, int index); + +template<> +std::string opt(std::false_type /*is_floating_point*/, int index) { + return __testlib_indexToArgv(index); +} + +template +T opt(std::true_type /*is_floating_point*/, int index) { + return T(optValueToLongDouble(__testlib_indexToArgv(index))); +} + +template +T opt(std::false_type /*is_integral*/, U /*is_unsigned*/, int index) { + return opt(std::is_floating_point(), index); +} + +template +T opt(std::true_type /*is_integral*/, std::false_type /*is_unsigned*/, int index) { + return optValueToIntegral(__testlib_indexToArgv(index), false); +} + +template +T opt(std::true_type /*is_integral*/, std::true_type /*is_unsigned*/, int index) { + return optValueToIntegral(__testlib_indexToArgv(index), true); +} + +template<> +bool opt(std::true_type /*is_integral*/, std::true_type /*is_unsigned*/, int index) { + std::string value = __testlib_indexToArgv(index); + if (value == "true" || value == "1") + return true; + if (value == "false" || value == "0") + return false; + __testlib_fail("Opts: opt by index '" + toString(index) + "': expected bool true/false or 0/1 but '" + + compress(value) + "' found"); +} + +/** + * Return the parsed argv by a given index. + */ +template +T opt(int index) { + return opt(std::is_integral(), std::is_unsigned(), index); +} + +/** + * Return the raw string value of an argv by a given index. + */ +std::string opt(int index) { + return opt(index); +} + +/** + * Return the parsed argv by a given index. If the index is bigger than + * the number of argv, return the given default_value. + */ +template +T opt(int index, const T &default_value) { + if (index >= int(__testlib_argv.size())) { + return default_value; + } + return opt(index); +} + +/** + * Return the raw string value of an argv by a given index. If the index is + * bigger than the number of argv, return the given default_value. + */ +std::string opt(int index, const std::string &default_value) { + return opt(index, default_value); +} + +template +T opt(std::false_type is_floating_point, const std::string &key); + +template<> +std::string opt(std::false_type /*is_floating_point*/, const std::string &key) { + return __testlib_keyToOpts(key); +} + +template +T opt(std::true_type /*is_integral*/, const std::string &key) { + return T(optValueToLongDouble(__testlib_keyToOpts(key))); +} + +template +T opt(std::false_type /*is_integral*/, U, const std::string &key) { + return opt(std::is_floating_point(), key); +} + +template +T opt(std::true_type /*is_integral*/, std::false_type /*is_unsigned*/, const std::string &key) { + return optValueToIntegral(__testlib_keyToOpts(key), false); +} + +template +T opt(std::true_type /*is_integral*/, std::true_type /*is_unsigned*/, const std::string &key) { + return optValueToIntegral(__testlib_keyToOpts(key), true); +} + +template<> +bool opt(std::true_type /*is_integral*/, std::true_type /*is_unsigned*/, const std::string &key) { + if (!has_opt(key)) + return false; + std::string value = __testlib_keyToOpts(key); + if (value == "true" || value == "1") + return true; + if (value == "false" || value == "0") + return false; + __testlib_fail("Opts: key '" + compress(key) + "': expected bool true/false or 0/1 but '" + + compress(value) + "' found"); +} + +/** + * Return the parsed opt by a given key. + */ +template +T opt(const std::string &key) { + return opt(std::is_integral(), std::is_unsigned(), key); +} + +/** + * Return the raw string value of an opt by a given key + */ +std::string opt(const std::string &key) { + return opt(key); +} + +/* Scorer started. */ + +enum TestResultVerdict { + SKIPPED, + OK, + WRONG_ANSWER, + RUNTIME_ERROR, + TIME_LIMIT_EXCEEDED, + IDLENESS_LIMIT_EXCEEDED, + MEMORY_LIMIT_EXCEEDED, + COMPILATION_ERROR, + CRASHED, + FAILED +}; + +std::string serializeVerdict(TestResultVerdict verdict) { + switch (verdict) { + case SKIPPED: return "SKIPPED"; + case OK: return "OK"; + case WRONG_ANSWER: return "WRONG_ANSWER"; + case RUNTIME_ERROR: return "RUNTIME_ERROR"; + case TIME_LIMIT_EXCEEDED: return "TIME_LIMIT_EXCEEDED"; + case IDLENESS_LIMIT_EXCEEDED: return "IDLENESS_LIMIT_EXCEEDED"; + case MEMORY_LIMIT_EXCEEDED: return "MEMORY_LIMIT_EXCEEDED"; + case COMPILATION_ERROR: return "COMPILATION_ERROR"; + case CRASHED: return "CRASHED"; + case FAILED: return "FAILED"; + } + throw "Unexpected verdict"; +} + +TestResultVerdict deserializeTestResultVerdict(std::string s) { + if (s == "SKIPPED") + return SKIPPED; + else if (s == "OK") + return OK; + else if (s == "WRONG_ANSWER") + return WRONG_ANSWER; + else if (s == "RUNTIME_ERROR") + return RUNTIME_ERROR; + else if (s == "TIME_LIMIT_EXCEEDED") + return TIME_LIMIT_EXCEEDED; + else if (s == "IDLENESS_LIMIT_EXCEEDED") + return IDLENESS_LIMIT_EXCEEDED; + else if (s == "MEMORY_LIMIT_EXCEEDED") + return MEMORY_LIMIT_EXCEEDED; + else if (s == "COMPILATION_ERROR") + return COMPILATION_ERROR; + else if (s == "CRASHED") + return CRASHED; + else if (s == "FAILED") + return FAILED; + ensuref(false, "Unexpected serialized TestResultVerdict"); + // No return actually. + return FAILED; +} + +struct TestResult { + int testIndex; + std::string testset; + std::string group; + TestResultVerdict verdict; + double points; + long long timeConsumed; + long long memoryConsumed; + std::string input; + std::string output; + std::string answer; + int exitCode; + std::string checkerComment; +}; + +std::string serializePoints(double points) { + if (std::isnan(points)) + return ""; + else { + char c[64]; + snprintf(c, 64, "%.03lf", points); + return c; + } +} + +double deserializePoints(std::string s) { + if (s.empty()) + return std::numeric_limits::quiet_NaN(); + else { + double result; +#ifdef _MSC_VER + ensuref(sscanf_s(s.c_str(), "%lf", &result) == 1, "Invalid serialized points"); +#else + ensuref(std::sscanf(s.c_str(), "%lf", &result) == 1, "Invalid serialized points"); +#endif + return result; + } +} + +std::string escapeTestResultString(std::string s) { + std::string result; + for (size_t i = 0; i < s.length(); i++) { + if (s[i] == '\r') + continue; + if (s[i] == '\n') { + result += "\\n"; + continue; + } + if (s[i] == '\\' || s[i] == ';') + result += '\\'; + result += s[i]; + } + return result; +} + +std::string unescapeTestResultString(std::string s) { + std::string result; + for (size_t i = 0; i < s.length(); i++) { + if (s[i] == '\\' && i + 1 < s.length()) { + if (s[i + 1] == 'n') { + result += '\n'; + i++; + continue; + } else if (s[i + 1] == ';' || s[i + 1] == '\\') { + result += s[i + 1]; + i++; + continue; + } + } + result += s[i]; + } + return result; +} + +std::string serializeTestResult(TestResult tr) { + std::string result; + result += std::to_string(tr.testIndex); + result += ";"; + result += escapeTestResultString(tr.testset); + result += ";"; + result += escapeTestResultString(tr.group); + result += ";"; + result += serializeVerdict(tr.verdict); + result += ";"; + result += serializePoints(tr.points); + result += ";"; + result += std::to_string(tr.timeConsumed); + result += ";"; + result += std::to_string(tr.memoryConsumed); + result += ";"; + result += escapeTestResultString(tr.input); + result += ";"; + result += escapeTestResultString(tr.output); + result += ";"; + result += escapeTestResultString(tr.answer); + result += ";"; + result += std::to_string(tr.exitCode); + result += ";"; + result += escapeTestResultString(tr.checkerComment); + return result; +} + +TestResult deserializeTestResult(std::string s) { + std::vector items; + std::string t; + for (size_t i = 0; i < s.length(); i++) { + if (s[i] == '\\') { + t += s[i]; + if (i + 1 < s.length()) + t += s[i + 1]; + i++; + continue; + } else { + if (s[i] == ';') { + items.push_back(t); + t = ""; + } else + t += s[i]; + } + } + items.push_back(t); + + ensuref(items.size() == 12, "Invalid TestResult serialization: expected exactly 12 items"); + + TestResult tr; + size_t pos = 0; + tr.testIndex = stoi(items[pos++]); + tr.testset = unescapeTestResultString(items[pos++]); + tr.group = unescapeTestResultString(items[pos++]); + tr.verdict = deserializeTestResultVerdict(items[pos++]); + tr.points = deserializePoints(items[pos++]); + tr.timeConsumed = stoll(items[pos++]); + tr.memoryConsumed = stoll(items[pos++]); + tr.input = unescapeTestResultString(items[pos++]); + tr.output = unescapeTestResultString(items[pos++]); + tr.answer = unescapeTestResultString(items[pos++]); + tr.exitCode = stoi(items[pos++]); + tr.checkerComment = unescapeTestResultString(items[pos++]); + + return tr; +} + +std::vector readTestResults(std::string fileName) { + std::ifstream stream; + stream.open(fileName.c_str(), std::ios::in); + ensuref(stream.is_open(), "Can't read test results file '%s'", fileName.c_str()); + std::vector result; + std::string line; + while (getline(stream, line)) + if (!line.empty()) + result.push_back(deserializeTestResult(line)); + stream.close(); + return result; +} + +std::function)> __testlib_scorer; + +struct TestlibScorerGuard { + ~TestlibScorerGuard() { + if (testlibMode == _scorer) { + std::vector testResults; + while (!inf.eof()) { + std::string line = inf.readLine(); + if (!line.empty()) + testResults.push_back(deserializeTestResult(line)); + } + inf.readEof(); + printf("%.3f\n", __testlib_scorer(testResults)); + } + } +} __testlib_scorer_guard; + +void registerScorer(int argc, char *argv[], std::function)> scorer) { + /* Suppress unused. */ + (void)(argc), (void)(argv); + + __testlib_ensuresPreconditions(); + + testlibMode = _scorer; + __testlib_set_binary(stdin); + + inf.init(stdin, _input); + inf.strict = false; + + __testlib_scorer = scorer; +} + +/* Scorer ended. */ + +/** + * Return the parsed opt by a given key. If no opts with the given key are + * found, return the given default_value. + * + * By calling this function, automatic ensurement for no unused opts will be + * done when the program is finalized. Call suppressEnsureNoUnusedOpts() to + * turn it off. + */ +template +T opt(const std::string &key, const T &default_value) { + if (!has_opt(key)) { + return default_value; + } + return opt(key); +} + +/** + * Return the raw string value of an opt by a given key. If no opts with the + * given key are found, return the given default_value. + * + * By calling this function, automatic ensurement for no unused opts will be + * done when the program is finalized. Call suppressEnsureNoUnusedOpts() to + * turn it off. + */ +std::string opt(const std::string &key, const std::string &default_value) { + return opt(key, default_value); +} + +/** + * Check if all opts are used. If not, __testlib_fail is called. + * Should be used after calling all opt() function calls. + * + * This function is useful when opt() with default_value for checking typos + * in the opt's key. + */ +void ensureNoUnusedOpts() { + for (const auto &opt: __testlib_opts) { + if (!opt.second.used) { + __testlib_fail(testlib_format_("Opts: unused key '%s'", compress(opt.first).c_str())); + } + } +} + +void suppressEnsureNoUnusedOpts() { + __testlib_ensureNoUnusedOptsSuppressed = true; +} + +void TestlibFinalizeGuard::autoEnsureNoUnusedOpts() { + if (__testlib_ensureNoUnusedOptsFlag && !__testlib_ensureNoUnusedOptsSuppressed) { + ensureNoUnusedOpts(); + } +} + +TestlibFinalizeGuard testlibFinalizeGuard; +#endif + +#ifdef __GNUC__ +__attribute__ ((format (printf, 1, 2))) +#endif +std::string testlib_format_(const char *fmt, ...) { + FMT_TO_RESULT(fmt, fmt, result); + return result; +} + +std::string testlib_format_(const std::string fmt, ...) { + FMT_TO_RESULT(fmt, fmt.c_str(), result); + return result; +} + +#if (__cplusplus >= 202002L && __has_include()) || __cpp_lib_format +template +std::string format(const char* fmt, Args&&... args) { + size_t size = size_t(std::snprintf(nullptr, 0, fmt, args...) + 1); + std::vector buffer(size); + std::snprintf(buffer.data(), size, fmt, args...); + return std::string(buffer.data()); +} + +template +std::string format(const std::string fmt, Args&&... args) { + size_t size = size_t(std::snprintf(nullptr, 0, fmt.c_str(), args...) + 1); + std::vector buffer(size); + std::snprintf(buffer.data(), size, fmt.c_str(), args...); + return std::string(buffer.data()); +} +#else +#ifdef __GNUC__ +__attribute__ ((format (printf, 1, 2))) +#endif +std::string format(const char *fmt, ...) { + FMT_TO_RESULT(fmt, fmt, result); + return result; +} + +std::string format(const std::string fmt, ...) { + FMT_TO_RESULT(fmt, fmt.c_str(), result); + return result; +} +#endif + +#endif diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/1.png" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/1.png" new file mode 100644 index 0000000000000000000000000000000000000000..0b0b5b21ad386507fc8cfa4895212a6b2164e981 Binary files /dev/null and "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/1.png" differ diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/code.md" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/code.md" new file mode 100644 index 0000000000000000000000000000000000000000..772bf6a81ee14e85a216596017ce93d1528a3155 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/code.md" @@ -0,0 +1,1262 @@ +#
高带宽内存(HBM)管理的自适应签名驱动虚拟段调度算法
+ +**文件**: `code.py` + +--- + +## 目录 + +1. [算法概述](#1-项目概述) + - 1.1 [问题背景](#11-问题背景) + - 1.2 [核心思想](#12-核心思想) + - 1.3 [设计动机](#13-设计动机) + - 1.3.1 [为什么需要虚拟段抽象?](#131-为什么需要虚拟段抽象) + - 1.3.2 [为什么需要三级混合驱逐策略?](#132-为什么需要三级混合驱逐策略) + - 1.3.3 [为什么需要松弛感知的前瞻预取?](#133-为什么需要松弛感知的前瞻预取) + - 1.3.4 [为什么需要工作负载自适应策略选择?](#134-为什么需要工作负载自适应策略选择) +2. [整体架构](#2-整体架构) +3. [核心组件详解](#3-核心组件详解) +4. [算法流程](#4-算法流程) +5. [使用说明及示例展示](#5-使用说明及示例展示) + +--- + +## 1. 算法概述 + +### 1.1 问题背景 + +本系统解决的是**高带宽内存(HBM)的智能调度问题**: +- **目标**: 最小化任务完成时间 +- **约束**: + - HBM 容量有限(M 字节) + - 同一段内存(的子集)可能被反复访问,并且不同的内存段之间可能存在交集 + - 内存访问时间满足$start_0 \le start_1 \le ... \le start_{n-1}$,这也意味着不同段内存的访问可能是同时发生的 + - 内存访问序列的每个四元组都被认为是一个计算过程,这个过程可以和内存卸载/加载过程并行,并且不影响彼此的执行时间 + - 内存卸载和加载过程,不能并行 + - 在初始状态时(第一条内存序列执行前),内存都是空闲的,所有的数据都未加载到内存 + - 访问序列中,$start_i$相等的条目应视为**同一计算任务**发出的请求。**同一计算任务的开始时间和持续时间均相同,且同一计算任务的内存访问可以并行完成,而不同计算任务之间只能串行完成。同一计算任务的内存访问操作必须同时开始,同时结束** + +- **挑战**: + - 瞬时的最大内存占用,不能超过内存的最大容量 + - 对于任意一段内存,在被访问的时间段内,必须固定在内存中(即这段内存不能处于被卸载的状态) + - I/O 操作开销巨大(每字节 40 个周期) + - 需要平衡内存利用率和 I/O 效率,寻求高效的内存加载卸载与计算的重叠 + +### 1.2 核心思想 + +#### 基础概念定义 + +首先明确几个核心术语: + +**1. 任务组** +- **定义**:具有**相同开始时间** ($start\_time$) 的内存访问请求集合 +- **构成**:一个任务组包含一个或多个内存访问请求,这些请求同时发起,代表一个计算任务 +- **关键属性**: + - `start_time`:任务组的开始时间 + - `duration`:计算持续时间(所有请求共享) + - `request_ids`:包含的请求 ID 列表 +- **示例**: + ``` + 时间 T0: 请求 [R0: addr=0, size=100], [R1: addr=200, size=50] + → 任务组 G0 (start_time=T0, requests=[R0, R1]) + + 时间 T5: 请求 [R2: addr=50, size=80] + → 任务组 G1 (start_time=T5, requests=[R2]) + ``` + +**2. 物理段** +- **定义**:地址空间中**不重叠的连续内存区域**,是内存管理的最小物理单元 +- **生成方法**:**地址空间离散化** (Address Space Discretization) + 1. 收集所有内存访问请求的边界点(起始地址和结束地址) + 2. 对边界点排序:$\{0, p_1, p_2, ..., p_k, L\}$ + 3. 相邻边界点间的区间即为物理段:$[p_i, p_{i+1})$ +- **示例**: + ``` + 地址空间大小 L = 1000 + 请求: R0=[0, 100), R1=[50, 150), R2=[200, 300) + + 边界点: {0, 50, 100, 150, 200, 300, 1000} + 物理段: + S0: [0, 50) ← 仅被 R0 访问 + S1: [50, 100) ← 被 R0, R1 共同访问 + S2: [100, 150) ← 仅被 R1 访问 + S3: [150, 200) ← 未被访问(空闲) + S4: [200, 300) ← 仅被 R2 访问 + S5: [300, 1000) ← 未被访问(空闲) + ``` + +**3. 访问模式** +- **定义**:描述**哪些任务组访问某个物理段**的时间序列 +- **表示方法**:任务组索引的有序列表 +- **示例**: + ``` + 物理段 S1: + - 任务组 G0 需要访问(包含请求覆盖 S1) + - 任务组 G2 需要访问 + - 任务组 G5 需要访问 + → 访问模式 = [0, 2, 5] + ``` + +**4. 签名** +- **定义**:访问模式的**元组表示**,用于标识相同访问模式的物理段 +- **形式化**:$\text{signature}(S) = (g_1, g_2, ..., g_k)$,其中 $g_i$ 是访问物理段 $S$ 的任务组索引 +- **作用**:作为虚拟段合并的**等价类标识** +- **示例**: + ``` + 物理段 S0: signature = (0, 2, 5) ← 被任务组 0, 2, 5 访问 + 物理段 S1: signature = (0, 2, 5) ← 相同签名! + 物理段 S2: signature = (1, 3) ← 不同签名 + + → S0 和 S1 可以合并为一个虚拟段(访问模式相同) + ``` + +**5. 虚拟段** +- **定义**:具有**相同签名**的物理段集合的逻辑抽象 +- **目的**:将多个物理段作为一个整体进行调度,实现批量 I/O +- **映射关系**: + ``` + 虚拟段 V0 → {物理段 S0, 物理段 S3, 物理段 S7} + (signature均为 (0, 2, 5)) + + 虚拟段 V1 → {物理段 S1, 物理段 S4} + (signature均为 (1, 3)) + ``` + +--- + +#### 创新点 + +本算法的核心创新在于**签名驱动的虚拟段抽象 + 工作负载自适应策略**,构建了一个多层次、智能化的内存调度框架: + +#### 1.2.1 签名驱动的虚拟段抽象 + +**定义**:将具有相同访问模式(签名)的物理段合并为虚拟段,作为统一的调度单元。 + +**机制**: +- **访问签名**:$\text{signature}(S) = (g_1, g_2, ..., g_k)$ - 访问物理段 $S$ 的任务组索引元组 +- **签名等价合并**:相同签名的物理段被抽象为同一虚拟段 +- **双层地址映射**:虚拟段(逻辑层)→ 物理段集合(物理层) + +**关键参数**:`enable_merge` (bool) - 控制是否启用虚拟段合并 + + +#### 1.2.2 三级混合驱逐策略 + +基于段的"未来价值"建立三级优先级系统: +- **优先级 1: 垃圾段** - 不再被访问的段,立即驱逐 +- **优先级 2: 求解器候选** - 单次驱逐可完全满足内存需求 +- **优先级 3: 贪心候选** - 基于距离-大小混合评分 + +**评分公式**:`score = distance / sqrt(size) - 0.5 × lock_penalty` +- 远期段优先:Belady最优替换算法的启发式实现 +- 小段偏好:最小化重新加载成本 + +#### 1.2.3 松弛感知的前瞻预取 + +NPU 计算期间存在 I/O 空闲时间窗口,应智能利用。 + +**机制**: +1. **松弛分数**:量化预取的紧迫性,考虑 I/O 密度因子 +2. **双重排序策略**:Late队列(紧急段,FCFS)+ Early队列(非紧急段,SJF) +3. **智能交换**:内存不足时找到"可替换段"进行交换 + +**优势**:计算-I/O 重叠最大化 + +**动态预取窗口(lookahead)调整**: + +预取窗口大小不是固定值,而是综合考虑三个维度动态调整: + +```python +lookahead = base × (任务组比例) × (内存因子) × (IO因子) +``` + +**调整维度**: +1. **任务组规模**:小规模(≤100)限制在20-30%,避免过度预取远期数据 +2. **内存压力**:高压(≥0.95)降低20%,减少内存占用 +3. **IO密度**:密集(>1.2)提升30%,充分利用I/O空闲时间 + +**设计理由**: +- 固定窗口在小规模场景会预取过多远期段,导致"预取→驱逐→重载"的浪费 +- 结合工作负载特征动态调整,在不同场景下找到预取深度的最优点 + +#### 1.2.4 工作负载自适应策略选择 + +不同的输入数据特征(瓶颈点)完全不同,不存在一种单一的算法能同时在所有场景下达到最优。因此需要根据瓶颈的不同选择相应的策略 + +**自适应机制**:基于三大特征指标动态选择策略: +- **max\_p** (内存压力) ≥ 0.95 → 禁用合并,精确控制 +- **io\_density** (IO密度) > 1.05 → 深度预取 +- **avg\_seg\_size** (平均段大小) > 12.5%M → 启用批量合并 + +**创新**: +- 滑动窗口 I/O 密度检测局部峰值 +- 相对段大小阈值(12.5%M)适应不同内存规模 +- 四级决策树寻找帕累托最优点 + +--- + +## 1.3 设计动机 + +本节深入剖析算法各个设计选择背后的动机,解答"为什么这样设计"。 + +### 1.3.1 为什么需要虚拟段抽象? + +虚拟段抽象解决了物理段调度的难题:**计算性能**、**决策质量**。 + +--- + +#### 难题一:计算性能瓶颈 + +**问题**:地址空间离散化后,逻辑对象被切割成大量微小物理段。 + +``` +典型场景: + - 物理段数:20,000(碎片化后) + - 任务组数:10,000 + - 调度循环复杂度:O(G × P²) + - Python 循环次数:~10⁸ → 直接超时(TLE) +``` + +**虚拟段解决方案**: +通过访问签名合并,使复杂度降低:O(G × V²),其中 V ≪ P,V为虚拟段数,P为任务组数,确保在时间限制内完成。 + +--- + +#### 难题二:决策质量缺陷 + +**问题**:"零存整取"陷阱 - 在物理段粒度做驱逐评分导致短视决策。 + +**典型驱逐评分公式**: +$$ +\text{score} = \frac{\text{distance\_to\_next\_use}}{\sqrt{\text{size}}} +$$ + +**公式解释**: +- **分子 distance_to_next_use**: 段到下次被访问的"距离"(任务组数) + - 距离越远 → 评分越高 → 越适合驱逐(Belady 最优替换的启发式) + - 理由:远期才用到的数据,现在驱逐影响小 + +- **分母 √size**: 段大小的平方根 + - 大小越大 → 评分越低 → 越不适合驱逐 + - 平方根而非线性:在大小段之间取得平衡,避免过度惩罚大段 + - 理由:大段重载成本高(cost = 40 × size),应优先保护 + +- **评分逻辑**: score 越高 → 优先级越高 → 越容易被驱逐 + - 理想目标:驱逐"远期使用 + 小段"的数据 + - 最小化重新加载的总成本 + + + + +**陷阱机制**: + +**物理段视角的缺陷**: +在评分函数中,大对象被离散化后的每个小碎片会因为 size 小而获得较高的 score,使其看起来是"理想"的驱逐目标。算法无法识别这些碎片实际属于同一个逻辑对象,导致: +- **局部最优决策**:每次驱逐都选择了"当前看起来最合理"的碎片 +- **累积性错误**:多次驱逐同一对象的不同碎片 +- **"零存整取"陷阱**:零散驱逐(看似低成本)→ 后续整体重载(实际高成本) + +**虚拟段视角的纠正**: +虚拟段将同一逻辑对象的所有碎片聚合为一个调度单元,其 size 反映了对象的真实大小。评分函数能够: +- **识别整体价值**:大对象因 size 大而获得较低 score,受到保护 +- **全局最优决策**:避免逐步蚕食重要对象 +- **稳定性保证**:防止同一对象的碎片在多次驱逐中被反复选中 + +> 虚拟段抽象将评分粒度从"碎片"提升至"对象",实现从**局部贪心**到**全局最优**的转变。这是 Belady 最优页面替换算法的启发式实现。 + +--- + +### 1.3.2 为什么需要三级混合驱逐策略? + +**动机**:单一驱逐策略无法同时兼顾**效率**、**精确性**和**最优性**。 + +驱逐决策面临的根本矛盾: +- **快速决策** vs **最优决策**:寻找全局最优驱逐方案是 NP-hard 问题 +- **确定性** vs **灵活性**:某些段必须驱逐(垃圾),某些段需要权衡 +- **一次性解决** vs **增量调整**:不同内存压力需要不同粒度的驱逐 + +#### 优先级 1: 垃圾段(Garbage Segments) + +**识别标准**:`next_use = ∞`(不再被任何后续任务组访问) + +**理由**: +- **确定性最高**:这些段必然不会再被使用,驱逐无任何代价 +- **零风险决策**:不存在"过早驱逐"的问题 +- **立即执行**:无需评分比较,发现即驱逐 + +**收益**:O(1) 决策时间,避免无用数据占用内存 + +--- + +#### 优先级 2: 求解器候选(Solver Candidates) + +**识别标准**:单次完整驱逐该段可完全满足当前内存需求 + +**理由**: +- **一次性解决**:避免多次驱逐带来的累积 I/O 成本 +- **操作原子性**:一个驱逐决策 + 一个驱逐操作,状态变化可预测 +- **减少碎片化**:整段驱逐优于多次部分驱逐 + +**决策逻辑**: +在所有"求解器候选"中选择评分最高的(距离最远的),平衡了: +- **满足需求**:腾出足够空间 +- **最小影响**:选择未来最晚使用的段 + +**收益**:单次操作解决问题,避免反复驱逐的抖动 + +--- + +#### 优先级 3: 贪心候选(Greedy Candidates) + +**识别标准**:不存在垃圾段和求解器候选,或需要精细控制 + +**理由**: +- **最优性追求**:基于评分函数选择全局"最不值得保留"的段 +- **部分驱逐能力**:可以只驱逐段的一部分(通过区间管理) +- **灵活性最高**:适应各种复杂场景 + +**评分函数**: +$$ +\text{score} = \frac{\text{distance}}{\sqrt{\text{size}}} - 0.5 \times \text{lock\_penalty} +$$ +- 综合考虑距离(未来价值)和大小(重载成本) +- lock_penalty 确保不驱逐正在使用的段 + +**收益**:在无法一次性解决时,逐步逼近最优解 + +--- + +#### 三级策略的协同效应 + +**自动降级机制**: +``` +优先级 1 → 优先级 2 → 优先级 3 + ↓ ↓ ↓ +确定性 一次性 最优性 +O(1) O(V) O(V log V) +``` + +**为什么不用单一策略?** + +| 策略 | 问题 | +|------|------| +| 只用垃圾段驱逐 | 很多场景无垃圾段可用,无法解决内存不足 | +| 只用求解器 | 可能无法找到单次驱逐即满足的段,导致死锁 | +| 只用贪心 | 每次都需要排序评分,O(V log V) 开销大;且可能多次小驱逐而非一次大驱逐 | + +**三级混合的优势**: +1. **效率递进**:优先使用 O(1) 的垃圾段,其次 O(V) 的求解器,最后才 O(V log V) 的贪心 +2. **适应性强**:覆盖高压、中压、低压三种内存场景 +3. **操作最优**:优先整段操作,减少碎片化状态管理 +4. **决策鲁棒**:有明确的降级路径,不会出现"无策略可用"的情况 + +--- + +### 1.3.3 为什么需要松弛感知的前瞻预取? + +**动机**:最大化计算-I/O 重叠,充分利用 NPU 计算期间的 I/O 空闲窗口。 + +#### 问题根源:I/O 与计算的时间错配 + +**传统被动加载的困境**: +``` +时间线: + t0: 任务组 G 开始执行 + t1: G 发现需要段 S(未在内存) + t2: 等待 I/O 加载 S(40 × size 周期) + t3: S 加载完成 + t4: G 继续执行 + +问题:t1 → t3 期间,NPU 完全空闲,浪费计算资源 +``` + +**理想状态**:在 NPU 计算当前任务组时,**预先加载**后续任务组需要的段。 + +--- + +#### 为什么需要"松弛感知"? + +**简单预取的失败**: +如果不考虑时间约束,盲目预取会导致: +1. **过早预取**:段加载后长时间不用 → 可能在使用前被驱逐 → 浪费 I/O +2. **过晚预取**:段在需要时还未加载完成 → NPU 仍需等待 → 无效预取 + +**松弛分数的引入**: +$$ +\text{slack} = \frac{T_{\text{deadline}} - T_{\text{current}}}{\rho} +$$ + +- **$T_{\text{deadline}}$**:段必须加载完成的截止时间(任务组开始时间) +- **$T_{\text{current}}$**:当前时间 + 预计加载时间 +- **$\rho$**:I/O 密度因子,调整不同工作负载的紧迫度 + +**松弛分数的含义**: +- `slack < 0`:**Late**(紧急)- 段已经错过最佳加载时机,必须立即加载 +- `slack ≥ 0`:**Early**(宽裕)- 段有足够时间加载,可以优化调度顺序 + +--- + +#### 双重排序策略的必要性 + +**Late 队列(紧急段)**: +- **排序**:FCFS(先到先服务) +- **理由**:这些段已经"迟到",必须尽快加载,避免任务组延迟 +- **优先级**:最高,优先于所有 Early 段 + +**Early 队列(非紧急段)**: +- **排序**:SJF(最短作业优先)+ 距离平局 +- **理由**: + 1. **小段优先**:快速完成小段加载,减少队列长度 + 2. **距离平局**:相同大小时,优先近期使用的段 +- **优先级**:次于 Late,但可以优化全局吞吐 + +**为什么不统一排序?** + +| 策略 | 问题 | +|------|------| +| 全部 FCFS | 无法优化 I/O 吞吐量,大段阻塞小段 | +| 全部 SJF | 紧急大段可能被延迟,导致任务组等待 | +| 全部按距离 | 忽略加载时间,可能导致近期段来不及加载 | + +**双重队列的优势**: +- **紧急保证**:Late 队列确保不会因优化而错过截止时间 +- **吞吐优化**:Early 队列最大化 I/O 利用率 +- **动态平衡**:段的状态会从 Early 转为 Late,自动调整优先级 + +--- + +#### 设置密度因子的理由 + +**问题**:不同工作负载的 I/O 强度差异巨大 + +**I/O 密集型任务**: +- 特征:频繁的小段访问,带宽需求高 +- 问题:如果用统一松弛阈值,容易被判定为 "Early" +- 后果:实际上很紧急,但被延后加载 + +**计算密集型任务**: +- 特征:少量大段访问,带宽需求低 +- 问题:如果用统一松弛阈值,容易被判定为 "Late" +- 后果:实际上不紧急,但占用预取资源 + +**密度因子的作用**: +$$ +\rho = \max\left(1.0, \frac{\text{bandwidth}}{\text{baseline}}\right) +$$ + +**参数说明**: +- **bandwidth**:任务组的实际 I/O 带宽需求 + - 计算公式:$\text{bandwidth} = \frac{\text{data\_demand}}{\text{duration}}$ + - `data_demand`:任务组需要的数据总量 + - `duration`:任务组的计算持续时间 + - 含义:单位时间需要传输的数据量 + +- **baseline**:基线带宽 = 0.025 + - 来源:$\frac{1}{40}$(每字节 I/O 成本为 40,周期的倒数) + - 作用:作为归一化参考值 + +**示例**: +``` +任务组 G: + - data_demand = 1000 + - duration = 20000 + - bandwidth = 1000 / 20000 = 0.05 + - ρ = max(1.0, 0.05 / 0.025) = max(1.0, 2.0) = 2.0 + +解释:该任务组的 I/O 需求是基线的 2 倍,属于 I/O 密集型 +``` + +- **I/O 密集**:$\rho > 1$ → slack 被放大 → 更容易进入 Late 队列 → **提前预取** +- **计算密集**:$\rho = 1$ → slack 正常 → 按标准流程 → **延后预取** + +**效果**:不同工作负载获得公平的预取机会 + +--- + +#### 智能交换的必要性 + +**场景**:预取时发现内存不足 + +**传统方案**: +- 放弃预取:浪费 I/O 空闲时间 +- 强行驱逐:可能驱逐即将使用的段 + +**智能交换策略**: +1. 寻找"可替换段":$\text{next\_use}(\text{victim}) > \text{next\_use}(\text{target})$ +2. 评估交换成本:$\text{offload}(\text{victim}) + \text{reload}(\text{target})$ +3. 仅在有净收益时执行交换 + +**收益**: +- 充分利用 I/O 窗口 +- 避免无效交换 +- 保持内存状态最优 + +--- + + +### 1.3.4 为什么需要工作负载自适应策略选择? + +没有一种单一策略能在所有工作负载下都达到最优,必须根据瓶颈特征动态选择策略。 + +#### 问题根源:工作负载的多样性 + +不同的输入数据会导致截然不同的系统瓶颈: + +**场景 1:内存高压场景** +- 特征:峰值内存利用率 > 95% +- 瓶颈:内存容量严重不足 +- 失败的策略:启用虚拟段合并 +- 原因:合并后的大虚拟段难以精确控制,容易超出内存限制 + +**场景 2:I/O 密集场景** +- 特征:频繁的小段访问,带宽需求 > 基准值 1.05 倍 +- 瓶颈:I/O 带宽饱和 +- 失败的策略:浅层预取(lookahead 小) +- 原因:无法充分利用计算空闲时间,I/O 成为长尾瓶颈 + +**场景 3:大段主导场景** +- 特征:平均段大小 > 12.5% 内存容量 +- 瓶颈:大段的频繁加载/卸载 +- 失败的策略:禁用虚拟段合并 +- 原因:无法利用批量 I/O 的优势,操作次数过多 + +**场景 4:平衡场景** +- 特征:内存、I/O 都不是主要瓶颈 +- 瓶颈:算法调度开销 +- 最优策略:中等配置,平衡各方面 + +--- + +#### 为什么静态策略会失败? + +**单一策略的困境**: + +| 策略配置 | 适合场景 | 失败场景 | +|---------|---------|---------| +| 始终启用合并 + 深度预取 | I/O 密集 + 大段 | 内存高压 | +| 始终禁用合并 + 浅预取 | 内存高压 | I/O 密集 + 大段 | +| 中等固定配置 | 平衡场景 | 极端场景 | + +**核心矛盾**: +- **虚拟段合并**:提升 I/O 效率 ↔ 增加内存压力 +- **深度预取**:降低 I/O 延迟 ↔ 占用更多内存 +- **大段优化**:减少操作次数 ↔ 粒度控制困难 + +**结论**:必须根据当前工作负载的**主要矛盾**动态选择策略。 + +--- + +#### 三大特征指标的设计 + +**1. max_p(内存压力)** + +**定义**:峰值内存利用率 = $\frac{\text{max}(\text{memory\_usage})}{\text{M}}$ + +**阈值**:≥ 0.95 + +**决策**: +- `max_p ≥ 0.95` → **禁用合并**(`enable_merge = False`) +- 理由:内存已接近极限,必须精确控制每个段 + +**失败案例**(如果不检测): +``` +场景:M = 1000,实际需求峰值 980 +启用合并:虚拟段大小不可控,可能瞬间超过 1000 → 崩溃 +禁用合并:每个物理段独立管理,可精确卸载到 980 以下 → 成功 +``` + +--- + +**2. io_density(I/O 密度)** + +**定义**:滑动窗口内的平均 I/O 带宽需求(相对于基准值 0.025) + +**阈值**:> 1.05 + +**决策**: +- `io_density > 1.05` → **深度预取**(eg. `lookahead = 200`) +- `io_density ≤ 1.05` → **标准预取**(eg. `lookahead = 100` 或 `50`) + +**理由**: +- I/O 密集时,必须激进预取,充分利用计算空闲时间 +- I/O 宽裕时,过度预取会浪费内存 + +**滑动窗口的必要性**: +- 全局平均会掩盖局部峰值 +- 滑动窗口(例如 50 个任务组)能捕获局部 I/O 热点 + +--- + +**3. avg_seg_size(平均段大小)** + +**定义**:相对于内存容量的平均段大小 = $\frac{\text{total\_seg\_size}}{\text{num\_segs} \times \text{M}}$ + +**阈值**:> 12.5% (0.125) + +**决策**: +- `avg_seg_size > 0.125 × M` → **启用批量合并**(`enable_merge = True`) +- 理由:大段场景下,合并能显著减少操作次数 + +**相对阈值的意义**: +- 相对阈值(12.5% M): + - M = 1000 → 阈值 = 125 + - M = 10000 → 阈值 = 1250 +- 确保在不同内存规模下都能正确识别"大段" + +--- + +#### 四级决策树 + +``` +决策流程: + + 检查 max_p ≥ 0.95? + ├─ 是 → 高内存压力 + │ enable_merge = False + │ base_lookahead = 100 + │ lookahead = adjust_lookahead(100, num_groups, max_p, io_density) + │ → 实际值根据任务组规模、内存压力、IO密度动态调整 + │ + └─ 否 → 内存宽裕 + ├─ 检查 io_density > 1.05? + │ ├─ 是 → I/O 瓶颈 + │ │ enable_merge = True + │ │ base_lookahead = 200 + │ │ lookahead = adjust_lookahead(200, num_groups, max_p, io_density) + │ │ → 小规模会缩小,IO密集会放大 + │ │ + │ └─ 否 → I/O 宽裕 + │ ├─ 检查 avg_seg_size > 0.125M? + │ │ ├─ 是 → 大段场景 + │ │ │ enable_merge = True + │ │ │ base_lookahead = 100 + │ │ │ lookahead = adjust_lookahead(100, ...) + │ │ │ + │ │ └─ 否 → 平衡场景 + │ │ enable_merge = True + │ │ base_lookahead = 50 + │ │ lookahead = adjust_lookahead(50, ...) +``` + +**动态调整说明**: +- 所有lookahead值都不是固定的,而是通过 `adjust_lookahead()` 根据工作负载特征动态计算 +- 调整因子: + - **任务组规模**:≤100时限制在20-50%,避免小规模场景过度预取 + - **内存压力**:高压时降低20% + - **IO密度**:密集时提升30% + +**优先级顺序的理由**: +1. **内存压力优先**:内存溢出是致命错误,必须首先避免 +2. **I/O 密度次之**:I/O 瓶颈影响全局性能 +3. **段大小最后**:在前两者都不是瓶颈时,优化操作粒度 + +--- + +## 2. 整体架构 + +### 2.1 系统组成 + +**整体流程**: + +``` +Selector (策略选择器) + ├─ 分析工作负载特征 + │ ├─ max_p (内存压力) + │ ├─ io_density (IO密度) + │ └─ avg_seg_size (平均段大小) + └─ 选择策略参数 (enable_merge, lookahead) + ↓ +UniversalPlanner (通用规划器) + ├─ 阶段 1: 输入解析 + ├─ 阶段 2: 地址空间离散化 + ├─ 阶段 3: 任务分组与签名生成 + ├─ 阶段 4: 虚拟段创建 + ├─ 阶段 5: 构建辅助数据结构 + ├─ 阶段 6: 初始化运行时状态 + ├─ 阶段 7: 主调度循环 + │ ├─ 步骤 A: 识别所需加载 + │ ├─ 步骤 B: 驱逐腾出空间 + │ ├─ 步骤 C: 执行加载 + │ ├─ 步骤 D: 执行 NPU 任务 + │ └─ 步骤 E: 预取未来数据 + └─ 阶段 8: 输出结果 + ↓ +OutputBuffer (输出缓冲区) + ├─ 合并连续的 Reload 操作 + ├─ 保持 Offload 操作独立 + └─ 按时间戳排序输出 +``` + +### 2.2 数据流图 + +``` +输入数据 (stdin) + ↓ +[解析] → 请求列表 (requests_raw) + ↓ +[离散化] → 物理段 (temp_starts, temp_sizes) + ↓ +[签名生成] → 访问签名 (temp_signatures) + ↓ +[虚拟段创建] → 虚拟段 (v_seg_*, enable_merge 控制) + ↓ +[调度循环] + ├─→ [驱逐] → Offload 操作 + ├─→ [加载] → Reload 操作 + ├─→ [执行] → Visit 操作 + └─→ [预取] → Reload 操作 + ↓ +输出缓冲区 → 格式化输出 +``` + +--- + +## 3. 核心组件详解 + +### 3.1 OutputBuffer (输出缓冲区) + +**功能**: 管理内存操作的缓冲和优化 + +**关键特性**: +- **智能合并**: 仅合并时间和地址都连续的 `Reload` 操作 +- **安全性**: `Offload` 操作永不合并,防止内存一致性问题 +- **排序输出**: 按时间戳排序确保操作顺序正确 + +**代码示例**: +```python +out_buf = OutputBuffer() +out_buf.append(0, "Reload", 100, 50) # 时间0, 加载地址100, 50字节 +out_buf.append(2000, "Reload", 150, 50) # 时间2000, 加载地址150, 50字节 +# 如果 2000 == 0 + 40*50 且 150 == 100+50, 则会合并 +``` + +**合并条件**: +```python +if last_end_time == time and last_end_addr == physical_addr: + # 合并: 扩展最后一个操作 + self.buffer[-1] = (last[0], last[1], last[2], last[3] + size) +``` + +### 3.2 UniversalPlanner (通用规划器) + +**核心参数**: +- `enable_merge` (bool): 是否启用虚拟段合并 + - `True`: 合并具有相同签名的物理段 → 减少 I/O 次数 + - `False`: 保持段独立 → 精确内存控制 +- `lookahead` (int): 预取前瞻深度(50-200) + - 值越大 → 预取越激进 → 适合 IO 密集型 + - 值越小 → 预取越保守 → 适合内存受限场景 + +#### 3.2.1 虚拟段 + +**概念**: 将具有相同访问模式的物理段合并为逻辑单元 + +**签名 (Signature)**: 访问该段的任务组索引元组 +```python +# 示例: 段被任务组 0, 2, 5 访问 +signature = (0, 2, 5) +``` + +**合并策略**: +```python +if enable_merge: + key = signature # 相同签名的段合并到同一虚拟段 +else: + key = (segment_id,) # 每个物理段独立为虚拟段 +``` + +**数据结构**: +```python +v_seg_total_size[vid] # 虚拟段总大小 +v_seg_access_groups[vid] # 访问该虚拟段的任务组列表 +v_seg_sub_segments[vid] # 物理子段列表 [(addr, size), ...] +v_seg_loaded_ranges[vid] # 已加载范围 [(start, end), ...] +``` + +#### 3.2.2 区间管理系统 + +**目的**: 跟踪虚拟段的已加载/未加载部分 + +**核心函数**: + +1. **`add_loaded_range(vid, start, end)`**: 标记逻辑范围为已加载 + ```python + # 添加范围并自动合并重叠区间 + v_seg_loaded_ranges[vid].append((start, end)) + v_seg_loaded_ranges[vid] = merge_intervals(...) + ``` + +2. **`remove_loaded_range(vid, start, end)`**: 标记逻辑范围为已卸载 + ```python + # 从区间集合中移除指定范围(可能分割现有区间) + for r_start, r_end in v_seg_loaded_ranges[vid]: + if 有重叠: + 分割区间并保留非重叠部分 + ``` + +3. **`merge_intervals(intervals)`**: 合并重叠或相邻区间 + ```python + # 输入: [(0, 10), (5, 15), (20, 30)] + # 输出: [(0, 15), (20, 30)] + ``` + +#### 3.2.3 辅助函数详解 + +**1. `emit_reload_virtual(time, vid, logical_end)`** +- **目的**: 为虚拟段的缺失部分生成 `Reload` 操作 +- **算法**: + 1. 计算缺失范围 = 目标范围 - 已加载范围 + 2. 将逻辑范围映射到物理地址 + 3. 生成连续的 `Reload` 操作 + 4. 更新全局 I/O 时间线和内存使用量 + +**2. `emit_offload_virtual(time, vid, amount)`** +- **目的**: 使用 LIFO 策略卸载虚拟段数据 +- **LIFO 理由**: 最近加载的数据可能最不重要(局部性原理的逆向应用) +- **算法**: + 1. 按逆序遍历已加载范围 + 2. 优先卸载最后加载的范围 + 3. 生成 `Offload` 操作并更新状态 + +**3. `get_next_use(vid, g_idx)`** +- **目的**: 获取虚拟段的下次访问时间(用于驱逐决策) +- **优化**: 使用缓存指针实现 O(1) 均摊查找 +- **返回**: 下一个任务组索引,或 `INFINITY_USE`(垃圾段) + +**4. `find_best_victim(needed, protect_set, cur_g_idx)`** +- **目的**: 选择最佳驱逐候选以腾出空间 +- **评分公式**: + ```python + score = (next_use_distance - current_group) / sqrt(segment_size) + ``` + - 优先驱逐:远期使用 + 小段(减少重新加载开销) + +- **三级优先级**: + 1. **垃圾段** (next_use = ∞): 最高优先级 + 2. **求解器候选** (单次驱逐解决赤字): 高优先级 + 3. **贪心候选** (部分驱逐): 低优先级 + +- **惩罚机制**: + ```python + final_score = score_val - (wait_time * 0.5) # 惩罚被 NPU 锁定的段 + ``` + +### 3.3 Selector (策略选择器) + +**功能**: 根据工作负载特征自动选择最优策略 + +**三大关键指标**: + +1. **max_p (内存压力)**: 峰值 HBM 利用率 + ```python + max_p = max(任务组内存需求 / HBM容量) + ``` + +2. **io_density (IO密度)**: 单位时间的数据带宽需求 + ```python + # 使用滑动窗口检测局部峰值 + window_io = sum(最近10个任务组的IO成本) + window_dur = sum(最近10个任务组的计算时间) + local_density = window_io / window_dur + ``` + +3. **avg_seg_size (平均段大小)**: 内存粒度 + ```python + avg_seg_size = 总地址空间 / 物理段数量 + ``` + +**决策逻辑** (优先级从高到低): + +```python +# 步骤1: 计算工作负载特征 +num_groups = len(任务组列表) +actual_io_density = max(窗口峰值IO密度, 全局平均IO密度) / 0.025 + +# 步骤2: 根据特征选择基础策略和lookahead +if max_p >= 0.95: + # 情况1: 内存高压 - 生存第一 + # 策略: 关闭合并 + 保守预取 + # 原因: 避免内存颠簸,精确控制内存使用 + base_lookahead = 100 + lookahead = adjust_lookahead(base_lookahead, num_groups, max_p, actual_io_density) + UniversalPlanner(enable_merge=False, lookahead=lookahead) + +elif max_window_io_density > 1.2 or avg_io_density > 1.05: + # 情况2: IO 瓶颈 - 带宽优先 + # 策略: 开启合并 + 深度预取 + # 原因: 最大化带宽利用,流式加载 + base_lookahead = 200 + lookahead = adjust_lookahead(base_lookahead, num_groups, max_p, actual_io_density) + UniversalPlanner(enable_merge=True, lookahead=lookahead) + +elif avg_seg_size > M * 0.125: + # 情况3: 大段场景 - IO 效率优先 + # 策略: 开启合并 + 中等预取 + # 原因: 大段适合批量加载 + base_lookahead = 100 + lookahead = adjust_lookahead(base_lookahead, num_groups, max_p, actual_io_density) + UniversalPlanner(enable_merge=True, lookahead=lookahead) + +else: + # 情况4: 默认平衡 + # 策略: 开启合并 + 浅预取 + base_lookahead = 50 + lookahead = adjust_lookahead(base_lookahead, num_groups, max_p, actual_io_density) + UniversalPlanner(enable_merge=True, lookahead=lookahead) +``` + +**关键**: +- `lookahead` 通过 `adjust_lookahead()` 根据**任务组规模**、**内存压力**和**IO密度**动态调整 +- 小规模场景(≤100任务组):自动限制在任务组总数的20-30% +- 内存高压场景:进一步降低20%,避免内存溢出 +- IO密集场景:适度提升30%,充分利用I/O空闲时间 + +--- + +## 4. 算法流程 + +### 4.1 预处理阶段 (阶段 1-6) + +#### 阶段 1: 输入解析 +```python +L = 地址空间大小 +M = HBM 容量 +N = 请求数量 +requests = [(地址, 大小, 开始时间, 持续时间, ID), ...] +``` + +#### 阶段 2: 地址空间离散化 +**目的**: 将连续地址空间分割为不重叠的物理段 + +```python +# 1. 收集所有边界点 +points = {0, L} +for request in requests: + points.add(request.addr) + points.add(request.addr + request.size) + +# 2. 构建物理段 +sorted_points = sorted(points) +for i in range(len(sorted_points) - 1): + segment = (sorted_points[i], sorted_points[i+1] - sorted_points[i]) +``` + +**示例**: +``` +请求: [addr=0, size=100], [addr=50, size=100] +边界点: {0, 50, 100, 150} +物理段: [0-50), [50-100), [100-150) +``` + +#### 阶段 3: 任务分组与签名生成 +**目的**: 识别每个段的访问模式 + +```python +# 1. 按开始时间分组请求 +groups[start_time] = [request_ids, ...] + +# 2. 为每个物理段生成访问签名 +for group_idx, group in enumerate(groups): + for request in group: + for segment in request覆盖的段: + segment.signature.append(group_idx) + +# 示例签名 +segment[0].signature = (0, 2, 5) # 被任务组 0, 2, 5 访问 +segment[1].signature = (1, 3) # 被任务组 1, 3 访问 +``` + +#### 阶段 4: 虚拟段创建 +**目的**: 合并具有相同访问模式的物理段 + +```python +if enable_merge: + # 相同签名的段合并 + for segment in 物理段: + key = segment.signature + virtual_segments[key].append(segment) +else: + # 每个物理段独立 + for segment in 物理段: + virtual_segments[(segment.id,)] = [segment] +``` + +**效果对比**: +``` +物理段: S1(签名A), S2(签名A), S3(签名B) + +enable_merge=True: + V1 = [S1, S2] # 签名A + V2 = [S3] # 签名B + +enable_merge=False: + V1 = [S1] + V2 = [S2] + V3 = [S3] +``` + +#### 阶段 5: 构建辅助数据结构 +```python +# 任务组 → 需要的虚拟段 +group_needed_vsegs[g_idx] = [vid1, vid2, ...] + +# 任务组元数据 +group_info[g_idx] = (start_time, duration, [request_ids]) + +# IO 密度因子 +group_density_factor[g_idx] = 任务组带宽需求 / 基线带宽 +``` + +#### 阶段 6: 初始化运行时状态 +```python +v_seg_loaded_ranges = [[] for _ in range(num_v_segs)] # 区间集合 +v_seg_unlock_time = [0] * num_v_segs # NPU 锁定时间 +loaded_v_seg_indices = set() # 已加载集合 +current_hbm_usage = 0 # 当前内存使用 +T_io = 0 # 全局 I/O 时间线 +``` + +### 4.2 主调度循环 (阶段 7) + +对每个任务组执行以下步骤: + +#### 步骤 A: 识别所需加载 +```python +to_load = [] +for vid in 当前任务组需要的虚拟段: + 未加载量 = 虚拟段总大小 - 已加载量 + if 未加载量 > 0: + to_load.append(vid) + need_bytes += 未加载量 +``` + +#### 步骤 B: 驱逐腾出空间 +```python +while current_hbm_usage + need_bytes > M: + victim_id, amount = find_best_victim(...) + if victim_id != -1: + emit_offload_virtual(T_io, victim_id, amount) + else: + break # 无法腾出足够空间 +``` + +**驱逐决策流程**: +``` +1. 扫描所有已加载段 + ↓ +2. 过滤保护集和锁定段 + ↓ +3. 分类候选: + - 垃圾段 (next_use = ∞) + - 求解器候选 (可一次性解决赤字) + - 贪心候选 (其他) + ↓ +4. 按优先级选择: + 垃圾 > 求解器 > 贪心 + ↓ +5. 返回 (victim_id, offload_amount) +``` + +#### 步骤 C: 执行加载 +```python +# 按物理地址排序,减少寻道时间 +to_load.sort(key=lambda v: 虚拟段的起始物理地址) + +for vid in to_load: + emit_reload_virtual(T_io, vid, 虚拟段总大小) +``` + +#### 步骤 D: 执行 NPU 任务 +```python +T_visit_start = max(T_npu_ready, start_time, T_data_ready) + +# 生成 Visit 操作 +for request_id in 任务组的请求: + out_buf.append(T_visit_start, "Visit", request_id, 0) + +T_group_end = T_visit_start + duration + +# 锁定所有需要的段 +for vid in 需要的虚拟段: + v_seg_unlock_time[vid] = max(v_seg_unlock_time[vid], T_group_end) +``` + +#### 步骤 E: 预取未来数据 + +**目的**: 利用 NPU 计算期间的空闲 I/O 时间 + +**预取窗口**: `[当前任务组+1, 当前任务组+lookahead]` + +**算法**: +1. **构建候选列表**: + ```python + for pg in range(current_group+1, current_group+lookahead): + for vid in 任务组pg需要的虚拟段: + if vid未完全加载: + slack = deadline - (T_io + load_cost) # 松弛时间 + distance = pg - current_group # 距离 + candidates.append((slack, distance, vid, ...)) + ``` + +2. **松弛分数 (Slack Score)**: + ```python + slack = (预期开始时间 - 预计完成时间) / 密度因子 + ``` + - `slack < 0`: 紧急(会延迟任务) + - `slack > 0`: 有余裕 + +3. **排序策略**: + ```python + late = [c for c in candidates if c.slack < 0] # 紧急 + early = [c for c in candidates if c.slack >= 0] # 非紧急 + + late.sort(by=任务组索引) # FCFS (先到先服务) + early.sort(by=(slack, distance)) # SJF + 距离平局 + + prefetch_order = late + early # 紧急优先 + ``` + +4. **执行预取**: + ```python + for candidate in prefetch_order: + 剩余时间 = T_group_end - T_io + if 剩余时间 <= 0: + break + + 可加载量 = min(未加载量, 剩余时间 // 40) + + # 内存不足时尝试交换 + if current_hbm_usage + 可加载量 > M: + 执行智能交换(...) + + # 执行预取 + if 可加载量 > 0: + emit_reload_virtual(T_io, vid, ...) + ``` + +5. **智能交换策略**: + ```python + # 查找交换受害者 + for loaded_vid in 已加载段: + if locked或protected: + continue + + next_use = get_next_use(loaded_vid, current_group) + + if next_use == INFINITY: + 立即使用为交换对象 # 垃圾段 + elif next_use > 预取目标的next_use: + score = (next_use - 目标) / sqrt(段大小) + 记录最高分段 + + # 执行交换 + if 找到合适受害者 且 交换成本合理: + emit_offload受害者 + emit_reload预取目标 + ``` + +### 4.3 输出阶段 (阶段 8) + +```python +ops = out_buf.flush() # 排序并格式化 +sys.stdout.write('\n'.join(ops)) +sys.stdout.write(f"\nFin {max(last_group_compute_end, T_io)}\n") +``` + +--- + + +## 5. 使用说明及示例展示 + +### 5.1 使用说明 + +```bash +# 运行调度器 +python3 code.py < input.txt > output.txt + +# 使用检查器验证 +./checker output.txt input.txt +``` + +### 5.2 示例展示 + +输入示例: + +``` +300 150 3 # L=300, M=150, N=3 +0 100 0 50 # Req0: addr=0, size=100, start=0, dur=50 +100 50 1000 10 # Req1: addr=100, size=50, start=1000, dur=10 +200 50 1000 10 # Req2: addr=200, size=50, start=1000, dur=10 +``` + +#### 1. 场景特征分析 +- **内存压力**: 峰值占用 100/150 = 66.7% (< 95%) → **非内存受限** +- **I/O 密度**: 极高。任务组 0 需要加载 100 字节 (4000 周期),但计算仅 50 周期。I/O 需求远超计算时间。 +- **段大小**: 平均段大小 66.7 字节 (> 12.5% M) → **大段场景** + +#### 2. 策略选择 +根据决策树优先级: +1. 检查内存压力 → 未超标 +2. 检查 I/O 密度 → **严重超标 (> 1.2)** + +**最终决策**: 命中 **I/O 瓶颈策略**。 +- **启用合并 (`enable_merge=True`)**: 将 Req1(100-150) 和 Req2(200-250) 合并为单一虚拟段,减少 I/O 头部开销。 +- **深度预取 (`lookahead=High`)**: 利用 Group 0 的微小计算窗口 (50 周期) 尽可能预取 Group 1 的数据(即使只能预取 1 字节),最大化带宽利用率。 + +#### 3. 输出结果及可视化 + +``` +Reload 0 0 100 +Visit 4000 0 +Reload 4000 100 1 +Offload 4050 0 100 +Reload 8050 101 49 +Reload 10010 200 50 +Visit 12010 1 +Visit 12010 2 +Fin 12020 +``` + +![](./1.png) + +本例的关键是`Reload 4000 100 1`的存在比较反直觉,分析可知,在 T=4000 到 T=4050 这段时间内,NPU 在忙碌,而 I/O 通道是空闲的。这是一个 50 周期 的“预取窗口”。算法检测到了这个空闲窗口,并尝试利用它来预取下一个任务组(Group 1)需要的数据。这证明了算法具有极高的 I/O 敏感度,即使是微小的碎片时间也不会放过,这正是它在 I/O 瓶颈场景下高效的原因。 + +此示例完美展示了算法在极端 I/O 密集场景下,如何通过激进的预取和合并策略来压榨每一分 I/O 性能。 + +--- + +## 复杂度分析与说明 + +### A. 时间复杂度分析 + +| 阶段 | 时间复杂度 | 说明 | +|------|-----------|------| +| 输入解析 | O(N) | N = 请求数量 | +| 地址空间离散化 | O(N log N) | 排序边界点 | +| 签名生成 | O(N × S) | S = 平均段覆盖数 | +| 虚拟段创建 | O(P) | P = 物理段数量 | +| 主调度循环 | O(G × V) | G = 任务组数, V = 虚拟段数 | +| 驱逐决策 | O(V) | 扫描已加载段 | +| 预取候选排序 | O(L log L) | L = lookahead深度 | +| **总体** | **O(N log N + G × V)** | | + +### B. 空间复杂度分析 + +| 数据结构 | 空间复杂度 | 说明 | +|---------|-----------|------| +| 请求列表 | O(N) | 原始请求 | +| 物理段 | O(P) | P ≈ 2N (worst case) | +| 虚拟段 | O(V) | V ≤ P | +| 区间集合 | O(V × K) | K = 平均区间数 | +| 输出缓冲区 | O(N + G) | Reload + Visit + Offload | +| **总体** | **O(N + V × K)** | | + +### C. 关键常量说明 + +| 常量 | 值 | 说明 | +|------|---|------| +| `INFINITY_USE` | 10^18 | 永不使用标记 | +| `递归限制` | 50000 | Python 栈深度 | +| `I/O 成本` | 40 周期/字节 | 硬件参数 | +| `密度基线` | 0.025 字节/周期 | 带宽基准 | +| `窗口大小` | 10 任务组 | IO密度分析 | diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/code.py" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/code.py" new file mode 100644 index 0000000000000000000000000000000000000000..933909bf591fa25070b17d351fe0bae964c30cdb --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/code.py" @@ -0,0 +1,1127 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +自适应签名驱动虚拟段调度算法 +================================================ + +本算法实现了一个智能的高带宽内存(HBM)调度系统,用于计算任务中的内存管理。 +它采用基于虚拟段的方法,结合自适应策略进行内存分配、驱逐和预取。 + +核心特性: + - 虚拟段合并以降低 I/O 开销 + - 基于前瞻的自适应预取 + - 使用距离评分的智能驱逐选择 + - 基于工作负载特征的动态策略选择 +""" + +import sys +import math +import gc + +# ========================================== +# 配置常量 (Configuration Constants) +# ========================================== + +# 提升递归深度限制,用于复杂调度场景中的深层调用栈 +sys.setrecursionlimit(50000) + +# 哨兵值,表示无限时间(用于不再被访问的段) +INFINITY_USE = 10**18 + + +# ========================================== +# 核心组件 (Core Components) +# ========================================== + +class OutputBuffer: + """ + 内存操作的输出缓冲区 + + 管理内存操作(Reload、Offload、Visit)的缓冲和格式化。 + 实现策略性的操作合并以降低 I/O 开销,同时保持正确性。 + + 关键设计决策: 仅合并 Reload 操作;Offload 操作保持独立以防止内存一致性问题。 + + 属性: + buffer (list): 操作元组列表 (time, op_type, physical_addr, size) + """ + + def __init__(self): + """初始化空的输出缓冲区""" + self.buffer = [] + + def append(self, time, op_type, physical_addr, size): + """ + 添加内存操作到缓冲区并智能合并 + + 参数: + time (int): 操作开始时间戳 + op_type (str): 操作类型 - "Reload"、"Offload" 或 "Visit" + physical_addr (int): 物理内存起始地址 + size (int): 操作大小(字节),Visit 操作为 0 + + 注意: + - size=0 的 Visit 操作总是被添加 + - Reload 操作在时间和地址连续时会被合并 + - Offload 操作永不合并以保持内存安全性 + """ + # 跳过大小为 0 的非 Visit 操作 + if size == 0 and op_type != "Visit": + return + + # 尝试与最后一个操作合并 + if self.buffer: + last = self.buffer[-1] + if last[1] == op_type: + # 关键:只合并 Reload 操作 + if op_type == "Reload": + last_end_time = last[0] + 40 * last[3] # 每字节 40 周期 + last_end_addr = last[2] + last[3] + + # 在时间和地址空间都连续时合并 + if last_end_time == time and last_end_addr == physical_addr: + self.buffer[-1] = (last[0], last[1], last[2], last[3] + size) + return + + # 无法合并时添加新操作 + self.buffer.append((time, op_type, physical_addr, size)) + + def flush(self): + """ + 按时间戳排序操作并格式化为输出 + + 返回: + list[str]: 格式化的操作字符串列表 + + 注意: + 操作按时间戳排序以确保时间顺序。 + Visit 操作格式化时不包含 size 参数。 + """ + self.buffer.sort(key=lambda x: x[0]) + res = [] + for t, op, addr, size in self.buffer: + if op == "Visit": + res.append(f"{op} {t} {addr}") + else: + res.append(f"{op} {t} {addr} {size}") + return res + + +class UniversalPlanner: + """ + 通用内存规划引擎 - 自适应策略 + + 核心调度算法,管理多个任务组的内存操作。 + 支持针对不同工作负载模式的可配置策略。 + + **核心概念**: + - 虚拟段(Virtual Segments): 通过合并具有相同访问模式的物理段创建的逻辑内存单元 + - 签名(Signatures): 定义哪些任务组访问每个段的时间访问模式 + - 前瞻(Lookahead): 预取时考虑的未来任务组数量 + + **策略**: + - enable_merge=True: 合并具有相同签名的段(更快的 I/O) + - enable_merge=False: 保持段独立(精确的内存控制) + - lookahead: 控制预取激进程度(50-200 个任务组) + + 属性: + input_data (list): 来自 stdin 的原始输入数据 + ENABLE_MERGE (bool): 是否合并具有相同签名的虚拟段 + LOOKAHEAD_DEPTH (int): 预取的未来任务组数量 + """ + + def __init__(self, input_data, enable_merge=True, lookahead=100): + """ + 使用指定策略初始化通用规划器 + + 参数: + input_data (list): 来自 stdin 的分词输入数据 + enable_merge (bool): 启用虚拟段合并(默认: True) + lookahead (int): 预取前瞻深度(默认: 100 个任务组) + """ + self.input_data = input_data + self.ENABLE_MERGE = enable_merge + self.LOOKAHEAD_DEPTH = lookahead + + def run(self): + """ + 执行主调度算法 + + 此方法协调整个调度过程: + 1. 解析输入并离散化地址空间 + 2. 构建任务组并计算访问签名 + 3. 基于合并策略创建虚拟段 + 4. 执行主调度循环,包括驱逐和预取 + 5. 输出格式化的操作 + + 算法使用多阶段方法平衡 I/O 效率和内存利用率。 + """ + # 在调度期间禁用垃圾回收以提升性能 + gc.disable() + + if not self.input_data: + print("Fin 0") + return + + # ========================================== + # 阶段 1: 输入解析 (Input Parsing) + # ========================================== + iterator = iter(self.input_data) + try: + self.L = int(next(iterator)) # 总地址空间大小 + self.M = int(next(iterator)) # HBM 容量 + self.N = int(next(iterator)) # 请求数量 + + # 解析所有请求: (地址, 大小, 开始时间, 持续时间, 请求ID) + self.requests_raw = [] + for i in range(self.N): + self.requests_raw.append(( + int(next(iterator)), # 地址 + int(next(iterator)), # 大小 + int(next(iterator)), # 开始时间 + int(next(iterator)), # 持续时间 + i # 请求 ID + )) + except StopIteration: + return + + if self.N == 0: + print("Fin 0") + return + + # ========================================== + # 阶段 2: 地址空间离散化 (Address Space Discretization) + # ========================================== + # 在所有请求边界创建段以避免部分重叠 + points = {0, self.L} + for r in self.requests_raw: + points.add(r[0]) # 起始地址 + points.add(r[0] + r[1]) # 结束地址 + sorted_points = sorted(list(points)) + + # 构建临时物理段 + temp_starts = [] + temp_sizes = [] + addr_to_temp_idx = {} + curr_idx = 0 + + for i in range(len(sorted_points) - 1): + start = sorted_points[i] + size = sorted_points[i+1] - start + if size > 0: + temp_starts.append(start) + temp_sizes.append(size) + addr_to_temp_idx[start] = curr_idx + curr_idx += 1 + + temp_num_segs = len(temp_starts) + + # ========================================== + # 阶段 3: 任务分组与签名生成 (Task Grouping and Signature Generation) + # ========================================== + # 按开始时间分组请求 + groups_map = {} + for i, r in enumerate(self.requests_raw): + s_time = r[2] + if s_time not in groups_map: + groups_map[s_time] = [] + groups_map[s_time].append(i) + + sorted_start_times = sorted(groups_map.keys()) + num_groups = len(sorted_start_times) + + # 计算每个物理段的访问签名 + # 签名 = 访问该段的任务组索引元组 + temp_signatures = [[] for _ in range(temp_num_segs)] + raw_group_infos = [] + + for g_idx, s_time in enumerate(sorted_start_times): + req_indices = groups_map[s_time] + duration = self.requests_raw[req_indices[0]][3] + raw_group_infos.append((s_time, duration, req_indices)) + + # 标记该任务组访问的段 + for r_idx in req_indices: + r = self.requests_raw[r_idx] + start_idx = addr_to_temp_idx[r[0]] + r_size = r[1] + acc = 0 + curr = start_idx + + # 遍历该请求覆盖的所有段 + while acc < r_size: + if not temp_signatures[curr] or temp_signatures[curr][-1] != g_idx: + temp_signatures[curr].append(g_idx) + acc += temp_sizes[curr] + curr += 1 + + # 将签名转换为元组以便哈希 + temp_signatures = [tuple(s) for s in temp_signatures] + + # ========================================== + # 阶段 4: 虚拟段创建 (Virtual Segment Creation) + # ========================================== + # 策略: 合并具有相同签名的段(如果启用) + sig_to_vid = {} + self.v_seg_total_size = [] + self.v_seg_access_groups = [] + self.v_seg_sub_segments = [] + curr_vid = 0 + + for i in range(temp_num_segs): + sig = temp_signatures[i] + if not sig: + continue + + # 核心策略分支: 按签名合并或保持独立 + key = sig if self.ENABLE_MERGE else (i,) + + if key not in sig_to_vid: + sig_to_vid[key] = curr_vid + self.v_seg_total_size.append(0) + self.v_seg_access_groups.append(sig) + self.v_seg_sub_segments.append([]) + curr_vid += 1 + + vid = sig_to_vid[key] + p_start = temp_starts[i] + p_size = temp_sizes[i] + self.v_seg_total_size[vid] += p_size + + # 合并同一虚拟段内的连续物理段 + subs = self.v_seg_sub_segments[vid] + if subs and subs[-1][0] + subs[-1][1] == p_start: + # 扩展最后一个子段 + subs[-1] = (subs[-1][0], subs[-1][1] + p_size) + else: + # 添加新子段 + subs.append((p_start, p_size)) + + num_v_segs = curr_vid + + # 预计算大小的平方根倒数用于评分 + # 用于驱逐选择: 较小的段具有更高的驱逐优先级 + self.inv_sqrt_sizes = [1.0 / (math.sqrt(s) + 1e-5) for s in self.v_seg_total_size] + + # ========================================== + # 阶段 5: 构建辅助数据结构 (Build Auxiliary Data Structures) + # ========================================== + # 将每个任务组映射到其需要的虚拟段 + self.group_needed_vsegs = [[] for _ in range(num_groups)] + group_data_demand = [0] * num_groups + + for vid, g_list in enumerate(self.v_seg_access_groups): + sz = self.v_seg_total_size[vid] + for g_idx in g_list: + self.group_needed_vsegs[g_idx].append(vid) + group_data_demand[g_idx] += sz + + # 构建任务组元数据: (开始时间, 持续时间, 请求IDs) + self.group_info = [] + for s_time, dur, r_indices in raw_group_infos: + self.group_info.append((s_time, dur, [self.requests_raw[x][4] for x in r_indices])) + + # 计算每个任务组的 I/O 密度因子 + # 更高的密度 = 更需要带宽的工作负载 + group_density_factor = [1.0] * num_groups + avg_bw = 0.025 # 基线带宽(1/40) + + for i in range(num_groups): + dur = self.group_info[i][1] + demand = group_data_demand[i] + bw = demand / dur if dur > 0 else demand + group_density_factor[i] = max(1.0, bw / avg_bw) + + # ========================================== + # 阶段 6: 初始化运行时状态 (Initialize Runtime State) + # ========================================== + # 跟踪每个虚拟段的已加载范围(区间集) + v_seg_loaded_ranges = [[] for _ in range(num_v_segs)] + + # 跟踪每个段变得可访问的时间(NPU 完成后) + v_seg_unlock_time = [0] * num_v_segs + + # 下次访问查找的缓存(优化以避免重复扫描) + v_seg_next_access_ptr = [0] * num_v_segs + + # 当前已加载的虚拟段 ID 集合 + loaded_v_seg_indices = set() + + # 当前 HBM 使用量(字节) + current_hbm_usage = 0 + + # 操作的输出缓冲区 + out_buf = OutputBuffer() + + # 全局 I/O 时间线(当前 I/O 完成时间) + T_io = 0 + + # 最后一次 NPU 完成时间(用于依赖跟踪) + last_group_compute_end = 0 + + # 所有任务组的预估开始时间(用于预取) + est_group_start_time = [0] * num_groups + curr_est = 0 + for i in range(num_groups): + curr_est = max(curr_est, self.group_info[i][0]) + est_group_start_time[i] = curr_est + curr_est += self.group_info[i][1] + + # ========================================== + # 辅助函数 + # ========================================== + + def get_loaded_amount(vid): + """ + 计算虚拟段的已加载总字节数 + + 参数: + vid (int): 虚拟段 ID + + 返回: + int: 当前 HBM 中已加载的总字节数 + """ + return sum(end - start for start, end in v_seg_loaded_ranges[vid]) + + def merge_intervals(intervals): + """ + 合并重叠或相邻的区间 + + 参数: + intervals (list): (start, end) 元组列表 + + 返回: + list: 按起始位置排序的合并区间 + """ + if not intervals: + return [] + intervals.sort() + merged = [intervals[0]] + for start, end in intervals[1:]: + if start <= merged[-1][1]: + # 重叠或相邻: 合并 + merged[-1] = (merged[-1][0], max(merged[-1][1], end)) + else: + # 不相交: 添加新区间 + merged.append((start, end)) + return merged + + def add_loaded_range(vid, start, end): + """ + 将逻辑范围标记为已加载 + + 参数: + vid (int): 虚拟段 ID + start (int): 逻辑起始偏移量 + end (int): 逻辑结束偏移量(不包含) + """ + if start >= end: + return + v_seg_loaded_ranges[vid].append((start, end)) + v_seg_loaded_ranges[vid] = merge_intervals(v_seg_loaded_ranges[vid]) + if v_seg_loaded_ranges[vid]: + loaded_v_seg_indices.add(vid) + + def remove_loaded_range(vid, start, end): + """ + 将逻辑范围标记为已卸载 + + 参数: + vid (int): 虚拟段 ID + start (int): 逻辑起始偏移量 + end (int): 逻辑结束偏移量(不包含) + """ + if start >= end: + return + new_ranges = [] + for r_start, r_end in v_seg_loaded_ranges[vid]: + if r_end <= start or r_start >= end: + # 无重叠: 保留区间 + new_ranges.append((r_start, r_end)) + else: + # 有重叠: 分割区间 + if r_start < start: + new_ranges.append((r_start, start)) + if r_end > end: + new_ranges.append((end, r_end)) + v_seg_loaded_ranges[vid] = new_ranges + if not v_seg_loaded_ranges[vid]: + if vid in loaded_v_seg_indices: + loaded_v_seg_indices.remove(vid) + + def emit_reload_virtual(time, vid, logical_end): + """ + 为虚拟段的缺失范围生成 Reload 操作 + + 该函数计算虚拟段的哪些部分尚未加载, + 并为这些范围生成 Reload 操作。 + + 参数: + time (int): 重载操作的开始时间 + vid (int): 虚拟段 ID + logical_end (int): 要加载到的目标逻辑结束偏移量 + + 副作用: + - 更新 T_io(全局 I/O 时间线) + - 更新 current_hbm_usage + - 向输出缓冲区添加 Reload 操作 + """ + nonlocal T_io, current_hbm_usage + + # 计算缺失的范围 + current_ranges = v_seg_loaded_ranges[vid] + missing = [] + + if not current_ranges: + # 没有加载任何内容: 重载整个范围 + missing = [(0, logical_end)] + else: + # 查找已加载范围中的空隙 + covered = 0 + for r_start, r_end in sorted(current_ranges): + if covered < r_start and covered < logical_end: + missing.append((covered, min(r_start, logical_end))) + covered = max(covered, r_end) + if covered < logical_end: + missing.append((covered, logical_end)) + + # 为缺失范围生成 Reload 操作 + op_time = time + total_loaded = 0 + + for l_start, l_end in missing: + acc = 0 + for p_start, p_size in self.v_seg_sub_segments[vid]: + seg_log_end = acc + p_size + overlap_s = max(acc, l_start) + overlap_e = min(seg_log_end, l_end) + + if overlap_s < overlap_e: + load_len = overlap_e - overlap_s + phys = p_start + (overlap_s - acc) + out_buf.append(op_time, "Reload", phys, load_len) + op_time += 40 * load_len # 每字节 40 周期 + total_loaded += load_len + + acc += p_size + if acc >= l_end: + break + + add_loaded_range(vid, l_start, l_end) + + # 更新全局状态 + T_io = op_time + current_hbm_usage += total_loaded + + def emit_offload_virtual(time, vid, amount): + """ + 使用 LIFO 策略为虚拟段生成 Offload 操作 + + 使用后进先出 (LIFO) 驱逐策略: 最近加载的范围 + 优先被卸载。 + + 参数: + time (int): 卸载操作的开始时间 + vid (int): 虚拟段 ID + amount (int): 要卸载的字节数 + + 副作用: + - 更新 T_io(全局 I/O 时间线) + - 更新 current_hbm_usage + - 向输出缓冲区添加 Offload 操作 + """ + nonlocal T_io, current_hbm_usage + + current_loaded = get_loaded_amount(vid) + if current_loaded == 0: + return + + offload_amt = min(amount, current_loaded) + + # LIFO 卸载策略: 从最后加载的范围开始 + ranges_to_offload = [] + acc_off = 0 + + for r_start, r_end in sorted(v_seg_loaded_ranges[vid], reverse=True): + if acc_off >= offload_amt: + break + needed = offload_amt - acc_off + if r_end - r_start <= needed: + # 卸载整个范围 + ranges_to_offload.append((r_start, r_end)) + acc_off += r_end - r_start + else: + # 从末尾部分卸载 + ranges_to_offload.append((r_end - needed, r_end)) + acc_off += needed + + # 生成 Offload 操作 + op_time = time + total_offloaded = 0 + + for l_start, l_end in ranges_to_offload: + acc = 0 + for p_start, p_size in self.v_seg_sub_segments[vid]: + seg_log_end = acc + p_size + overlap_s = max(acc, l_start) + overlap_e = min(seg_log_end, l_end) + + if overlap_s < overlap_e: + off_len = overlap_e - overlap_s + phys = p_start + (overlap_s - acc) + out_buf.append(op_time, "Offload", phys, off_len) + op_time += 40 * off_len + total_offloaded += off_len + + acc += p_size + if acc >= l_end: + break + + remove_loaded_range(vid, l_start, l_end) + + # 更新全局状态 + T_io = op_time + current_hbm_usage -= total_offloaded + + def get_next_use(vid, g_idx): + """ + 获取下一个将访问此虚拟段的任务组索引 + + 使用缓存指针实现 O(1) 均摊查找。 + + 参数: + vid (int): 虚拟段 ID + g_idx (int): 当前任务组索引 + + 返回: + int: 访问此段的下一个任务组索引,或 INFINITY_USE + """ + ac_list = self.v_seg_access_groups[vid] + ptr = v_seg_next_access_ptr[vid] + + # 将指针推进到当前任务组之后 + while ptr < len(ac_list) and ac_list[ptr] <= g_idx: + ptr += 1 + + v_seg_next_access_ptr[vid] = ptr # 缓存供下次查找使用 + return ac_list[ptr] if ptr < len(ac_list) else INFINITY_USE + + def find_best_victim(needed, protect_set, cur_g_idx, io_limit=None): + """ + 使用基于距离的评分查找最佳驱逐虚拟段 + + 评分策略: + 1. 垃圾(无未来使用): 最高优先级 + 2. 求解器候选(单次驱逐解决赤字): 高优先级 + 3. 贪心候选(部分驱逐): 较低优先级 + + 评分公式: distance_to_next_use / sqrt(segment_size) + - 偏好远期使用的段 + - 惩罚大段(避免重新加载开销) + + 参数: + needed (int): 即将操作需要的字节数 + protect_set (set): 不能被驱逐的段 ID 集合 + cur_g_idx (int): 当前任务组索引 + io_limit (int, optional): 候选的最大解锁时间 + + 返回: + tuple: (victim_id, amount_to_offload) 或 (-1, 0) 如果未找到受害者 + """ + deficit = (current_hbm_usage + needed) - self.M + if deficit <= 0: + return (-1, 0) + + # 对候选进行分类 + inf_cands = [] # 垃圾(无未来使用) + solver_cands = [] # 可以一次驱逐解决赤字 + all_cands = [] # 所有有效候选 + + for vid in loaded_v_seg_indices: + if vid in protect_set: + continue + if io_limit and v_seg_unlock_time[vid] > io_limit: + continue + + nu = get_next_use(vid, cur_g_idx) + ut = v_seg_unlock_time[vid] + + if nu == INFINITY_USE: + # 垃圾: 不会再被使用 + inf_cands.append(((1, 0, -ut), vid)) + else: + # 计算优先级分数 + dist = nu - cur_g_idx + score_val = dist * self.inv_sqrt_sizes[vid] + wait = max(0, ut - T_io) + final_score = score_val - (wait * 0.5) # 惩罚被锁定的段 + + all_cands.append(((0, final_score, -ut), vid)) + + # 检查该段是否可以单独解决赤字 + if get_loaded_amount(vid) >= deficit: + solver_cands.append(((0, final_score, -ut), vid)) + + # 优先级 1: 首先驱逐垃圾段 + if inf_cands: + inf_cands.sort(key=lambda x: x[0], reverse=True) + vid = inf_cands[0][1] + return (vid, get_loaded_amount(vid)) # 完全驱逐 + + # 优先级 2: 优先使用求解器候选(一次性解决方案) + best_solver = max(solver_cands, key=lambda x: x[0]) if solver_cands else None + best_greedy = max(all_cands, key=lambda x: x[0]) if all_cands else None + + if best_solver: + # 使用求解器,除非贪心明显更好 + if not best_greedy or best_solver[0][1] * 1.1 > best_greedy[0][1]: + amt = get_loaded_amount(best_solver[1]) + return (best_solver[1], min(amt, deficit)) + + # 优先级 3: 使用贪心候选(部分驱逐) + if best_greedy: + amt = get_loaded_amount(best_greedy[1]) + return (best_greedy[1], min(amt, deficit)) + + # 未找到有效的受害者 + return (-1, 0) + + # ========================================== + # 阶段 7: 主调度循环 (Main Scheduling Loop) + # ========================================== + for g_idx in range(num_groups): + start_time, duration, req_ids = self.group_info[g_idx] + needed = self.group_needed_vsegs[g_idx] + T_npu_ready = last_group_compute_end + + # -------------------------------------------- + # 步骤 A: 识别所需的加载 (Identify Required Loads) + # -------------------------------------------- + to_load = [] + need_bytes = 0 + + for vid in needed: + rem = self.v_seg_total_size[vid] - get_loaded_amount(vid) + if rem > 0: + to_load.append(vid) + need_bytes += rem + + # -------------------------------------------- + # 步骤 B: 驱逐以腾出空间 (Evict to Make Space) + # -------------------------------------------- + while current_hbm_usage + need_bytes > self.M: + vid, amt = find_best_victim(need_bytes, set(needed), g_idx) + if vid != -1: + t_start = max(T_io, v_seg_unlock_time[vid]) + emit_offload_virtual(t_start, vid, amt) + else: + # 无法腾出足够空间: 退出并尽力而为 + break + + # -------------------------------------------- + # 步骤 C: 执行加载 (Execute Loads) + # -------------------------------------------- + if to_load: + # 按物理地址排序以减少寻道时间 + to_load.sort(key=lambda v: self.v_seg_sub_segments[v][0][0]) + + for vid in to_load: + rem = self.v_seg_total_size[vid] - get_loaded_amount(vid) + if rem > 0: + emit_reload_virtual(T_io, vid, self.v_seg_total_size[vid]) + + # -------------------------------------------- + # 步骤 D: 执行 Visit + # -------------------------------------------- + T_data_ready = T_io + T_visit_start = max(T_npu_ready, start_time, T_data_ready) + T_group_end = T_visit_start + duration + est_group_start_time[g_idx] = T_visit_start + + # 生成 Visit 操作 + for rid in req_ids: + out_buf.append(T_visit_start, "Visit", rid, 0) + + # 锁定所有需要的段直到 NPU 完成 + for vid in needed: + v_seg_unlock_time[vid] = max(v_seg_unlock_time[vid], T_group_end) + + # -------------------------------------------- + # 步骤 E: 预取未来数据(如果时间允许) + # -------------------------------------------- + if g_idx < num_groups - 1 and T_io < T_group_end: + # 查找第一个数据不完整的任务组(屏障) + first_missing = -1 + lookahead = min(num_groups, g_idx + self.LOOKAHEAD_DEPTH) + + for pg in range(g_idx+1, lookahead): + pg_need = self.group_needed_vsegs[pg] + ok = True + for v in pg_need: + if get_loaded_amount(v) < self.v_seg_total_size[v]: + ok = False + break + if not ok: + first_missing = pg + break + + if first_missing == -1: + first_missing = lookahead + 1 + + # 构建带松弛分数的预取候选 + cands = [] + debt = 0 # 累积的先前任务组 I/O 债务 + + for pg in range(g_idx+1, lookahead): + pg_need = self.group_needed_vsegs[pg] + deadline = max(est_group_start_time[pg], T_group_end) + den = group_density_factor[pg] + grp_unloaded = 0 + + for v in pg_need: + rem = self.v_seg_total_size[v] - get_loaded_amount(v) + if rem > 0: + cost = 40 * rem + slack = (deadline - (T_io + cost + debt * 40)) / den + distance = pg - g_idx # 到下次使用的距离 + cands.append((slack, distance, v, rem, pg)) + grp_unloaded += rem + + debt += grp_unloaded + + # 排序候选: Late(紧急)优先,然后 Early (SJF) + late, early = [], [] + for c in cands: + if c[0] < 0: # 负松弛: 紧急 + late.append(c) + else: + early.append(c) + + late.sort(key=lambda x: x[4]) # 按任务组 FCFS + early.sort(key=lambda x: (x[0], x[1])) # SJF 带距离平局 + cands = late + early + + # 执行预取 + for slack, distance, vid, rem, pg in cands: + t_left = T_group_end - T_io + if t_left <= 0: + break + + real_rem = self.v_seg_total_size[vid] - get_loaded_amount(vid) + if real_rem <= 0: + continue + + max_load = t_left // 40 + if max_load <= 0: + break + + amt = min(real_rem, max_load) + + # 保护: 如果内存紧张,不要预取屏障之后的数据 + if pg > first_missing and current_hbm_usage + amt >= self.M: + continue + + # 如果需要,尝试与一个次要的段交换 + if current_hbm_usage + amt > self.M: + deficit = (current_hbm_usage + amt) - self.M + best_s = -1 + best_score = -float('inf') + + # 查找交换受害者(垃圾或远期段) + for cv in loaded_v_seg_indices: + if cv == vid: + continue + if v_seg_unlock_time[cv] > T_io: + continue + + nu = get_next_use(cv, g_idx) + if nu <= pg: # 不要交换目标之前需要的段 + continue + + if nu == INFINITY_USE: + # 找到垃圾: 立即使用 + best_s = cv + break + + # 根据距离和大小评分 + sc = (nu - pg) * self.inv_sqrt_sizes[cv] + if get_loaded_amount(cv) >= deficit and sc > best_score: + best_score = sc + best_s = cv + + # 如果有利可图,执行交换 + if best_s != -1: + swap_amt = min(get_loaded_amount(best_s), deficit) + off_cost = 40 * swap_amt + + if T_io + off_cost + 40 <= T_group_end: + emit_offload_virtual(T_io, best_s, swap_amt) + new_max = (T_group_end - T_io) // 40 + amt = min(real_rem, new_max) + else: + amt = 0 # 交换时间不足 + else: + amt = 0 # 没有合适的受害者 + + # 执行预取 + if amt > 0 and current_hbm_usage + amt <= self.M: + emit_reload_virtual(T_io, vid, get_loaded_amount(vid) + amt) + + # 更新 NPU 时间线 + last_group_compute_end = T_group_end + + # ========================================== + # 阶段 8: 输出结果 (Output Results) + # ========================================== + ops = out_buf.flush() + sys.stdout.write('\n'.join(ops)) + sys.stdout.write(f"\nFin {max(last_group_compute_end, T_io)}\n") + + +# ========================================== +# 策略选择器 (Strategy Selector) +# ========================================== + +class Selector: + """ + 智能策略选择器 - 基于工作负载特征自动选择最优调度策略 + + 分析三大关键指标: + 1. max_p (内存压力): HBM 峰值利用率 + 2. io_density (IO密度): 带宽需求强度 (字节/时间) + 3. avg_seg_size (平均段大小): 内存粒度 + + 选择逻辑: + - 高内存压力 (≥0.95) → 保守策略,避免颠簸 + - 高IO密度 (>1.05) → 激进合并 + 深度预取 + - 大段场景 (M * 0.125) → 合并策略,提升IO效率 + - 默认 → 平衡策略 + """ + + def __init__(self): + """从标准输入读取并分词""" + self.input_data = sys.stdin.read().split() + + @staticmethod + def adjust_lookahead(base_lookahead, num_groups, max_p, io_density): + """ + 综合考虑任务组数量、内存压力和IO密度动态调整预取深度 + + 多维度调整策略: + 1. 基于任务组数量确定基础比例 + 2. 根据内存压力调整(高压→保守,宽裕→激进) + 3. 根据IO密度调整(密集→激进,宽裕→保守) + + 参数: + base_lookahead (int): 策略指定的基础预取深度 + num_groups (int): 任务组总数 + max_p (float): 内存压力(0-1,峰值利用率) + io_density (float): IO密度(相对于基线0.025) + + 返回: + int: 调整后的预取深度 + """ + # 步骤1: 基于任务组数量的基础比例 + if num_groups <= 40: + # 极小规模:基础比例20% + base_ratio = 0.20 + min_val = 8 + elif num_groups <= 100: + # 小规模:基础比例30% + base_ratio = 0.30 + min_val = 10 + elif num_groups <= 300: + # 中规模:基础比例50% + base_ratio = 0.50 + min_val = 15 + else: + # 大规模:使用完整基础值,不再缩放 + return min(base_lookahead, num_groups - 1) + + # 步骤2: 根据内存压力调整 + if max_p >= 0.95: + # 内存高压:更保守(减少20%) + pressure_factor = 0.8 + elif max_p >= 0.85: + # 中等压力:略微保守(减少10%) + pressure_factor = 0.9 + else: + # 内存宽裕:保持原比例 + pressure_factor = 1.0 + + # 步骤3: 根据IO密度调整 + if io_density > 1.2: + # IO密集:更激进(增加30%) + io_factor = 1.3 + elif io_density > 1.05: + # 中等IO:略微激进(增加15%) + io_factor = 1.15 + else: + # IO宽裕:保持原比例 + io_factor = 1.0 + + # 综合调整:基础比例 × 内存因子 × IO因子 + final_ratio = base_ratio * pressure_factor * io_factor + scaled = int(num_groups * final_ratio) + + # 确保在合理范围内 + return min(base_lookahead, max(min_val, scaled)) + + def analyze_and_run(self): + """ + 分析工作负载特征并运行最优规划器 + + 流程: + 1. 解析输入,提取工作负载特征 + 2. 计算三大关键指标 + 3. 基于指标阈值选择最优策略 + 4. 实例化并运行选定的规划器 + """ + if not self.input_data: + return + + # ========================================== + # 步骤 1: 解析输入数据 + # ========================================== + iterator = iter(self.input_data) + try: + L = int(next(iterator)) # 地址空间大小 + M = int(next(iterator)) # HBM 容量 + N = int(next(iterator)) # 请求数量 + + requests = [] + for i in range(N): + requests.append(( + int(next(iterator)), # 地址 + int(next(iterator)), # 大小 + int(next(iterator)), # 开始时间 + int(next(iterator)), # 持续时间 + i # 请求 ID + )) + except StopIteration: + return + + # ========================================== + # 步骤 2: 特征提取 - 平均段大小 + # ========================================== + # 对地址空间进行离散化,计算物理段数量 + points = {0, L} + for r in requests: + points.add(r[0]) + points.add(r[0] + r[1]) + + sorted_points = sorted(list(points)) + temp_cnt = 0 + for i in range(len(sorted_points) - 1): + if sorted_points[i+1] > sorted_points[i]: + temp_cnt += 1 + + avg_seg_size = L / temp_cnt if temp_cnt else 0 + + # ========================================== + # 步骤 3: 特征提取 - 内存压力与 IO 密度 + # ========================================== + # 按开始时间分组请求 + groups = {} + for r in requests: + if r[2] not in groups: + groups[r[2]] = [] + groups[r[2]].append(r) + + sorted_times = sorted(groups.keys()) + + # 内存压力:各时间片的峰值内存占用率 + max_p = 0 + + # IO 密度分析:使用滑动窗口检测局部 IO 峰值 + window_size = 10 # 窗口大小(任务组数量) + window_io_bytes = [] + window_compute_time = [] + + max_window_io_density = 0.0 + total_io_all = 0 + total_dur_all = 0 + + for i, t in enumerate(sorted_times): + reqs = groups[t] + dur = reqs[0][3] + mem = sum(r[1] for r in reqs) + + # 更新峰值内存压力 + max_p = max(max_p, mem / M) + + # 累积窗口数据(IO 成本 = 数据量 × 40 周期/字节) + curr_io = mem * 40 + window_io_bytes.append(curr_io) + window_compute_time.append(dur) + + total_io_all += curr_io + total_dur_all += dur + + # 维护固定窗口大小 + if len(window_io_bytes) > window_size: + window_io_bytes.pop(0) + window_compute_time.pop(0) + + # 计算窗口内的 IO 密度 + win_io_sum = sum(window_io_bytes) + win_dur_sum = sum(window_compute_time) + + if win_dur_sum > 0: + local_density = win_io_sum / win_dur_sum + max_window_io_density = max(max_window_io_density, local_density) + + # 全局平均 IO 密度 + avg_io_density = total_io_all / total_dur_all if total_dur_all > 0 else 0 + + # ========================================== + # 步骤 4: 策略决策 + # ========================================== + + # 计算任务组总数(用于动态调整预取深度) + num_groups = len(sorted_times) + + # 计算实际使用的IO密度(取窗口峰值和全局平均的较大值) + actual_io_density = max(max_window_io_density, avg_io_density) / 0.025 + + # 优先级 1: 内存高压 - 生存第一 + if max_p >= 0.95: + # 关闭合并,使用保守的预取深度,避免内存颠簸 + base_lookahead = 100 + lookahead = self.adjust_lookahead(base_lookahead, num_groups, max_p, actual_io_density) + solver = UniversalPlanner(self.input_data, enable_merge=False, lookahead=lookahead) + + # 优先级 2: IO 瓶颈 - 带宽优先 + elif max_window_io_density > 1.2 or avg_io_density > 1.05: + # 开启合并 + 深度预取,最大化带宽利用 + # UniversalPlanner 会根据每个任务组的密度因子自动调整 + base_lookahead = 200 + lookahead = self.adjust_lookahead(base_lookahead, num_groups, max_p, actual_io_density) + solver = UniversalPlanner(self.input_data, enable_merge=True, lookahead=lookahead) + + # 优先级 3: 大段场景 - IO 效率优先 + # 使用相对于内存大小的阈值(M 的 12.5%)而非固定值 + elif avg_seg_size > M * 0.125: + # 大段数据适合合并加载,使用中等预取深度 + base_lookahead = 100 + lookahead = self.adjust_lookahead(base_lookahead, num_groups, max_p, actual_io_density) + solver = UniversalPlanner(self.input_data, enable_merge=True, lookahead=lookahead) + + # 优先级 4: 默认平衡策略 + else: + # 适用于通用场景,平衡内存与 IO + base_lookahead = 50 + lookahead = self.adjust_lookahead(base_lookahead, num_groups, max_p, actual_io_density) + solver = UniversalPlanner(self.input_data, enable_merge=True, lookahead=lookahead) + + # 执行选定的策略 + solver.run() + + +# ========================================== +# 程序入口 (Entry Point) +# ========================================== + +if __name__ == "__main__": + Selector().analyze_and_run() diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/log/code.log" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/log/code.log" new file mode 100644 index 0000000000000000000000000000000000000000..c88433bbc0b609820f24d5bbd7ed9f7380a1468d --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/log/code.log" @@ -0,0 +1,73 @@ +评测开始于: 2025年11月28日 星期五 08时00分38秒 CST +----------------------------------------- +--- 检验结果: infile_01.txt --- +ok All tasks finish at 48050 +--- 检验结果: infile_02.txt --- +ok All tasks finish at 14020 +--- 检验结果: infile_03.txt --- +ok All tasks finish at 13020 +--- 检验结果: infile_04.txt --- +ok All tasks finish at 102410 +--- 检验结果: infile_05.txt --- +ok All tasks finish at 12040 +--- 检验结果: infile_06.txt --- +ok All tasks finish at 12020 +--- 检验结果: infile_07.txt --- +ok All tasks finish at 13020 +--- 检验结果: infile_08.txt --- +ok All tasks finish at 4030 +--- 检验结果: infile_09.txt --- +ok All tasks finish at 26070 +--- 检验结果: infile_10.txt --- +ok All tasks finish at 20200 +--- 检验结果: infile_11.txt --- +ok All tasks finish at 22050 +--- 检验结果: infile_12.txt --- +ok All tasks finish at 20320 +--- 检验结果: infile_13.txt --- +ok All tasks finish at 12010 +--- 检验结果: infile_14.txt --- +ok All tasks finish at 14100 +--- 检验结果: infile_15.txt --- +ok All tasks finish at 22010 +--- 检验结果: infile_16.txt --- +ok All tasks finish at 28010 +--- 检验结果: infile_17.txt --- +ok All tasks finish at 14100 +--- 检验结果: infile_18.txt --- +ok All tasks finish at 16005 +--- 检验结果: infile_19.txt --- +ok All tasks finish at 16050 +--- 检验结果: infile_20.txt --- +ok All tasks finish at 14050 +--- 检验结果: infile_21.txt --- +ok All tasks finish at 16020 +--- 检验结果: infile_22.txt --- +ok All tasks finish at 14010 +--- 检验结果: infile_23.txt --- +ok All tasks finish at 16120 +--- 检验结果: infile_24.txt --- +ok All tasks finish at 14020 +--- 检验结果: infile_25.txt --- +ok All tasks finish at 16150 +--- 检验结果: infile_26.txt --- +ok All tasks finish at 20050 +--- 检验结果: infile_27.txt --- +ok All tasks finish at 10051 +--- 检验结果: infile_28.txt --- +ok All tasks finish at 8100 +--- 检验结果: infile_29.txt --- +ok All tasks finish at 18010 +--- 检验结果: infile_30.txt --- +ok All tasks finish at 102030 +--- 检验结果: infile_31.txt --- +ok All tasks finish at 10420 +--- 检验结果: infile_32.txt --- +ok All tasks finish at 41670 +--- 检验结果: infile_33.txt --- +ok All tasks finish at 208010 +--- 检验结果: infile_34.txt --- +ok All tasks finish at 100010 +--- 检验结果: infile_35.txt --- +ok All tasks finish at 20010 +----------------------------------------- diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/run_all.sh" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/run_all.sh" new file mode 100755 index 0000000000000000000000000000000000000000..541b9677badf7ae04a008657348e3cd8d6a6909a --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/code/run_all.sh" @@ -0,0 +1,123 @@ +#!/bin/bash +# +# 批量运行解题程序并使用 checker 进行校验的脚本 +# (v4 - 适配 infile/ 和 outfile/ 子目录) +# + +# --- 1. 配置 (已根据您的路径修改) --- + +# 您的 Python 解释器 +PYTHON_INTERP="python3" + +# 您的解题程序 (Python 脚本) +if [ -n "$1" ]; then + SOLUTION_SCRIPT="$1" +else + SOLUTION_SCRIPT="/Users/chenxi/Documents/比赛/CCF/code/code.py" +fi + +# 存放所有测试用例的目录 +BASE_DIR="/Users/chenxi/Documents/比赛/CCF/checker/example" +INFILE_DIR="${BASE_DIR}/infile" +OUTFILE_DIR="${BASE_DIR}/outfile" + +# 检验结果的日志文件 (将创建在您运行此脚本的当前目录下) +SCRIPT_BASENAME=$(basename "$SOLUTION_SCRIPT" .py) +LOG_FILE="/Users/chenxi/Documents/比赛/CCF/code/log/${SCRIPT_BASENAME}.log" + +# checker 程序的可执行文件路径 +CHECKER_EXEC="/Users/chenxi/Documents/比赛/CCF/checker/checker" + +# --- 2. 检查环境 --- + +# 确保 checker 存在 +if [ ! -f "$CHECKER_EXEC" ]; then + echo "❌ 错误: checker 程序 '$CHECKER_EXEC' 未找到。" + echo " (来自 Makefile 的信息,请确保已编译)" + exit 1 +fi + +# 确保您的解题程序 (Python 脚本) 存在 +if [ ! -f "$SOLUTION_SCRIPT" ]; then + echo "❌ 错误: 您的解题程序 '$SOLUTION_SCRIPT' 未找到。" + exit 1 +fi + +# 确保测试数据输入目录存在 +if [ ! -d "$INFILE_DIR" ]; then + echo "❌ 错误: 测试数据输入目录 '$INFILE_DIR' 未找到。" + echo " 请确保您的 infile 文件存放在该目录下。" + exit 1 +fi + +# 确保测试数据输出目录存在 (如果不存在,则自动创建) +if [ ! -d "$OUTFILE_DIR" ]; then + echo "⚠️ 警告: 输出目录 '$OUTFILE_DIR' 未找到,将自动创建..." + mkdir -p "$OUTFILE_DIR" + if [ $? -ne 0 ]; then + echo "❌ 错误: 无法创建输出目录 '$OUTFILE_DIR'。" + exit 1 + fi +fi + +# 检查 Python 解释器是否存在 +if ! command -v "$PYTHON_INTERP" &> /dev/null; then + echo "❌ 错误: 未找到 Python 解释器 '$PYTHON_INTERP'。" + exit 1 +fi + +# --- 3. 运行和校验 --- + +# 清空之前的日志文件 +echo "评测开始于: $(date)" > "$LOG_FILE" +echo "日志将写入到 $LOG_FILE" +echo "-----------------------------------------" | tee -a "$LOG_FILE" + +# 查找 $INFILE_DIR 目录下的所有 infile*.txt 文件 +shopt -s nullglob +file_count=0 +for local_infile in "$INFILE_DIR"/infile*.txt ; do + + file_count=$((file_count + 1)) + + # --- 关键改动: 构造输出文件路径 --- + # 1. 获取输入文件名 (e.g., "infile_01.txt") + infile_basename=$(basename "$local_infile") + + # 2. 替换 "infile" 为 "outfile" (e.g., "outfile_01.txt") + outfile_basename="${infile_basename/infile/outfile}" + + # 3. 组合成完整的输出路径 + local_outfile="${OUTFILE_DIR}/${outfile_basename}" + + # (日志记录名保持不变) + case_name="$infile_basename" + + echo "▶️ 正在处理: ${case_name}" + + # 1. 运行您的程序,生成 outfile + echo " [步骤 1] 运行您的解题程序 ($PYTHON_INTERP $SOLUTION_SCRIPT)..." + # $local_infile 指向 .../infile/infile_01.txt + # $local_outfile 指向 .../outfile/outfile_01.txt + "$PYTHON_INTERP" "$SOLUTION_SCRIPT" < "$local_infile" > "$local_outfile" + + # 2. 运行 checker 进行检验 + echo " [步骤 2] 运行 checker 进行校验..." + + echo "--- 检验结果: ${case_name} ---" >> "$LOG_FILE" + + # 运行 checker + # 根据 README.md 和 checker.cc,第三个参数未被使用,但必须提供 + "$CHECKER_EXEC" "$local_infile" "$local_outfile" "$local_outfile" >> "$LOG_FILE" 2>&1 + + tail -n 1 "$LOG_FILE" + echo "" + +done + +if [ $file_count -eq 0 ]; then + echo "⚠️ 警告: 在 '$INFILE_DIR' 中未找到任何 'infile*.txt' 文件。" +fi + +echo "-----------------------------------------" | tee -a "$LOG_FILE" +echo "✅ 批量评测完成。详细结果请查看: $LOG_FILE" \ No newline at end of file diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/visual.py" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/visual.py" new file mode 100644 index 0000000000000000000000000000000000000000..386fcf9d895ac0928b3fa0015dac5afd70ac8703 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/visual.py" @@ -0,0 +1,339 @@ +#!/usr/bin/env python3 + +import os +import sys +import glob +import re +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +from matplotlib.patches import Rectangle +from pathlib import Path +from typing import Dict, List, Any + +# --- 1. 配置路径 --- +# (无变化) +BASE_DIR = Path("/Users/chenxi/Documents/比赛/CCF/checker/example") +IN_DIR = BASE_DIR / "infile" +OUT_DIR = BASE_DIR / "outfile" +FIG_DIR = BASE_DIR / "fig" + +# --- 2. 辅助函数:解析文件 --- +# (无变化) + +def parse_infile(filepath: Path) -> Dict[str, Any]: + """解析 infile,返回 L, M, N 和一个请求字典。""" + requests = {} + with open(filepath, 'r') as f: + line = f.readline().split() + L, M, N = int(line[0]), int(line[1]), int(line[2]) + + for i in range(N): + line = f.readline().split() + requests[i] = { + 'addr': int(line[0]), + 'size': int(line[1]), + 'start': int(line[2]), + 'time': int(line[3]), + 'id': i + } + return {'L': L, 'M': M, 'N': N, 'requests': requests} + +def parse_outfile(filepath: Path) -> List[Dict[str, Any]]: + """解析 outfile,返回一个操作列表。""" + ops = [] + with open(filepath, 'r') as f: + for line in f: + parts = line.strip().split() + if not parts: + continue + + op_name = parts[0] + T = int(parts[1]) + op = {'op': op_name, 'T': T} + + if op_name in ("Reload", "Offload"): + op['A'] = int(parts[2]) + op['S'] = int(parts[3]) + elif op_name == "Visit": + op['A'] = int(parts[2]) # 这是任务 ID + + ops.append(op) + return ops + +# --- 3. [!!! 核心修改 !!!] 甘特图可视化函数 --- + +def visualize_case(infile_path: Path, outfile_path: Path, fig_path: Path): + """ + 为单个用例生成并保存 甘特图 可视化。 + [!!! MODIFIED !!!] + 图像包含三个子图: + 1. "Requests" (来自 Infile 的理论时间) + 2. "Visits" (NPU 实际调度时间) + 3. "IO" (Reload/Offload 串行时间) + """ + + try: + in_data = parse_infile(infile_path) + out_data = parse_outfile(outfile_path) + except Exception as e: + print(f" [!] 失败: 解析文件时出错 {e}") + return + + L, M, N = in_data['L'], in_data['M'], in_data['N'] + requests = in_data['requests'] + + # 模拟 HBM 状态 (不变) + try: + hbm_state = np.zeros(L, dtype=np.int8) + except MemoryError: + print(f" [!] 失败: 无法为 L={L} 分配内存。跳过...") + return + + # [!!! NEW !!!] 存储 "Requests" (理论) 时间窗口 + request_bars = [] + + visit_bars = [] # 存储 (task_id, T_start, duration, label, color) + io_bars = [] # 存储 (T_start, duration, label, color) + T_fin = 0 + + colors = list(mcolors.TABLEAU_COLORS.values()) + + # [!!! NEW !!!] + # 遍历 Infile requests 来收集 "Requests" 绘图数据 + for task_id, req in requests.items(): + color = colors[task_id % len(colors)] + request_bars.append(( + task_id, + req['start'], # 理论最早开始时间 + req['time'], # 理论执行时间 + f"R{task_id}", # 标签 (Request) + color + )) + # 确保 T_fin 至少和理论时间一样长 + T_fin = max(T_fin, req['start'] + req['time']) + + + # 遍历 Outfile (不变) + for op in out_data: + T_start = op['T'] + T_fin = max(T_fin, T_start) + + if op['op'] == 'Reload': + A, S = op['A'], op['S'] + try: + chunk = hbm_state[A : A+S] + new_bytes = (chunk == 0) + actual_reload_size = int(np.sum(new_bytes)) + chunk[new_bytes] = 1 + except ValueError: + continue + + duration = 40 * actual_reload_size + if duration > 0: + io_bars.append((T_start, duration, 'R', colors[0])) # Blue + T_fin = max(T_fin, T_start + duration) + + elif op['op'] == 'Offload': + A, S = op['A'], op['S'] + try: + chunk = hbm_state[A : A+S] + offload_bytes = (chunk == 1) + actual_offload_size = int(np.sum(offload_bytes)) + chunk[offload_bytes] = 0 + except ValueError: + continue + + duration = 40 * actual_offload_size + if duration > 0: + io_bars.append((T_start, duration, 'O', colors[3])) # Red + T_fin = max(T_fin, T_start + duration) + + elif op['op'] == 'Visit': + task_id = op['A'] + if task_id not in requests: + continue + + req = requests[task_id] + duration = req['time'] + color = colors[task_id % len(colors)] + + visit_bars.append((task_id, T_start, duration, f"V{task_id}", color)) + T_fin = max(T_fin, T_start + duration) + + elif op['op'] == 'Fin': + T_fin = max(T_fin, T_start) + break + + # --- 4. 绘图 --- + + # [!!! MODIFIED !!!] + # 高度比例:为 N 个 Requests, N 个 Visits 和 1 个 IO 任务分配空间 + height_ratio = [max(1, N), max(1, N), 1.5] # Requests, Visits, IO + + fig, (ax_requests, ax_visits, ax_io) = plt.subplots( + 3, 1, # [!!! MODIFIED !!!] 3 个子图 + figsize=(20, 3 + N * 1.0 + 1.0), # [!!! MODIFIED !!!] 动态计算总高度 + sharex=True, + gridspec_kw={'height_ratios': height_ratio} + ) + + # [!!! MODIFIED !!!] 移除图名 + # fig.suptitle(f"Scheduling Gantt Chart: {outfile_path.name}", fontsize=16) + + # [!!! NEW !!!] + # --- 子图 1: "Requests" (Infile 理论时间) --- + ax_requests.set_ylabel("Requests (Input)") + + for (task_id, T_start, duration, label, color) in request_bars: + ax_requests.barh( + y=task_id, + width=duration, + left=T_start, + height=0.8, + edgecolor='black', + alpha=0.4, # 使用较低的 alpha + color=color, + hatch='//' # 使用斜线填充以区分 + ) + # 在条形图内部添加标签 + ax_requests.text( + T_start + duration/2, + task_id, + label, + ha='center', + va='center', + color='black', + alpha=0.7, + fontsize=9 + ) + + # Y 轴设置 (让 Lane 0 在顶部) + ax_requests.set_yticks(range(N)) + ax_requests.set_yticklabels([f"Lane {i}" for i in range(N)]) + ax_requests.set_ylim(N - 0.5, -0.5) + ax_requests.grid(axis='x', linestyle=':', alpha=0.6) + + # --- 子图 2: "Visits" (NPU 实际调度) --- + # [!!! MODIFIED !!!] + ax_visits.set_ylabel("Visits (Scheduled)") + + for (task_id, T_start, duration, label, color) in visit_bars: + ax_visits.barh( + y=task_id, + width=duration, + left=T_start, + height=0.8, + edgecolor='black', + alpha=0.8, # [!!! MODIFIED !!!] 较高的 alpha + color=color + ) + # 在条形图内部添加标签 + ax_visits.text( + T_start + duration/2, + task_id, + label, + ha='center', + va='center', + color='white', + weight='bold', + fontsize=10 + ) + + # Y 轴设置 (让 Lane 0 在顶部) + ax_visits.set_yticks(range(N)) + ax_visits.set_yticklabels([f"Lane {i}" for i in range(N)]) + ax_visits.set_ylim(N - 0.5, -0.5) + ax_visits.grid(axis='x', linestyle=':', alpha=0.6) + + # --- 子图 3: "IO" (Reload/Offload) --- + # (无变化) + ax_io.set_ylabel("IO (Serial)") + + for (T_start, duration, label, color) in io_bars: + ax_io.barh( + y=0, + width=duration, + left=T_start, + height=0.8, + edgecolor='black', + alpha=0.75, + color=color + ) + # 标签 + ax_io.text( + T_start + duration/2, + 0, + label, + ha='center', + va='center', + color='white', + weight='bold', + fontsize=10 + ) + + # Y 轴设置 + ax_io.set_yticks([0]) + ax_io.set_yticklabels(["IO Lane 0"]) + ax_io.set_ylim(-0.5, 0.5) + ax_io.grid(axis='x', linestyle=':', alpha=0.6) + + # --- 5. 保存 --- + ax_io.set_xlabel("Time (T)") + ax_io.set_xlim(left=-T_fin * 0.01, right=T_fin * 1.05) + + # [!!! MODIFIED !!!] 调整 tight_layout 以适应无标题 + plt.tight_layout(rect=[0, 0.03, 1, 1.0]) + plt.savefig(fig_path) + plt.close(fig) + print(f" [✓] 成功: 已保存到 {fig_path.name}") + + +# --- 6. 主程序 --- +# (无变化) +def main(): + print(f"--- 开始可视化 (Gantt 模式) ---") + print(f"输入目录: {IN_DIR}") + print(f"输出目录: {OUT_DIR}") + print(f"图形目录: {FIG_DIR}") + print("--------------------") + + FIG_DIR.mkdir(parents=True, exist_ok=True) + + infile_paths = sorted(glob.glob(str(IN_DIR / "infile_*.txt"))) + + if not infile_paths: + print(f"[!] 错误: 在 {IN_DIR} 中未找到 'infile_*.txt' 文件。") + return + + for infile_path in infile_paths: + infile_path = Path(infile_path) + + match = re.search(r'infile_(\w+)\.txt', infile_path.name) + if not match: + print(f"[!] 跳过: 无法解析文件名 {infile_path.name}") + continue + + file_id = match.group(1) + outfile_name = f"outfile_{file_id}.txt" + outfile_path = OUT_DIR / outfile_name + + fig_name = f"fig_{file_id}_gantt.png" + fig_path = FIG_DIR / fig_name + + print(f"[*] 正在处理: {infile_path.name} -> {outfile_name}") + + if not outfile_path.exists(): + print(f" [!] 跳过: 未找到对应的 {outfile_name}") + continue + + try: + visualize_case(infile_path, outfile_path, fig_path) + except Exception as e: + print(f" [!] 失败: 处理 {infile_path.name} 时发生意外错误: {e}") + + print("--------------------") + print("--- 可视化完成 ---") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/visual_1.py" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/visual_1.py" new file mode 100644 index 0000000000000000000000000000000000000000..6788bb394270670b2de9946dcca7fb90e06c0343 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/visual_1.py" @@ -0,0 +1,473 @@ +#!/usr/bin/env python3 + +import os +import sys +import glob +import re +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +from matplotlib.patches import Rectangle +from pathlib import Path +from typing import Dict, List, Any, Tuple + +# --- 1. 配置路径 --- +BASE_DIR = Path("/Users/chenxi/Documents/比赛/CCF/checker/example") +IN_DIR = BASE_DIR / "infile" +OUT_DIR = BASE_DIR / "outfile" +FIG_DIR = BASE_DIR / "fig" + +# --- 2. 辅助函数:解析文件 --- + +def parse_infile(filepath: Path) -> Dict[str, Any]: + """解析 infile,返回 L, M, N 和一个请求字典。""" + requests = {} + with open(filepath, 'r') as f: + line = f.readline().split() + L, M, N = int(line[0]), int(line[1]), int(line[2]) + + for i in range(N): + line = f.readline().split() + requests[i] = { + 'addr': int(line[0]), + 'size': int(line[1]), + 'start': int(line[2]), + 'time': int(line[3]), + 'id': i + } + return {'L': L, 'M': M, 'N': N, 'requests': requests} + +def parse_outfile(filepath: Path) -> List[Dict[str, Any]]: + """解析 outfile,返回一个操作列表。""" + ops = [] + with open(filepath, 'r') as f: + for line in f: + parts = line.strip().split() + if not parts: + continue + + op_name = parts[0] + T = int(parts[1]) + op = {'op': op_name, 'T': T} + + if op_name in ("Reload", "Offload"): + op['A'] = int(parts[2]) + op['S'] = int(parts[3]) + elif op_name == "Visit": + op['A'] = int(parts[2]) # 这是任务 ID + + ops.append(op) + return ops + + +# --- 3. [NEW] 变长时间轴映射函数 --- + +def create_variable_timescale(events: List[Tuple[int, int]], + compression_threshold: int = 1000, + min_segment_display: int = 50) -> Tuple[callable, callable, List[Tuple[int, int]]]: + """ + 创建变长时间轴映射。 + + 参数: + events: [(start_time, end_time), ...] 所有事件的时间区间 + compression_threshold: 超过这个长度的空闲段会被压缩 + min_segment_display: 压缩后空闲段的最小显示宽度 + + 返回: + (real_to_display, display_to_real, segments) + - real_to_display: 真实时间 -> 显示时间的映射函数 + - display_to_real: 显示时间 -> 真实时间的映射函数 + - segments: [(real_start, real_end, display_start, display_end, is_compressed), ...] + """ + if not events: + # 空事件列表,返回恒等映射 + identity = lambda x: x + return identity, identity, [] + + # 1. 收集所有时间点并排序 + time_points = set() + for start, end in events: + time_points.add(start) + time_points.add(end) + + sorted_points = sorted(time_points) + if not sorted_points: + identity = lambda x: x + return identity, identity, [] + + # 2. 识别活跃段和空闲段 + segments = [] + display_cursor = 0 + + for i in range(len(sorted_points)): + real_start = sorted_points[i] + real_end = sorted_points[i+1] if i+1 < len(sorted_points) else sorted_points[i] + + if real_start == real_end: + continue + + # 检查这个时间段是否有活动 + has_activity = any(start <= real_start < end for start, end in events) + + segment_length = real_end - real_start + + if has_activity or segment_length <= compression_threshold: + # 活跃段或短段:正常显示 + display_length = segment_length + is_compressed = False + else: + # 长空闲段:压缩 + display_length = min_segment_display + is_compressed = True + + display_end = display_cursor + display_length + segments.append((real_start, real_end, display_cursor, display_end, is_compressed)) + display_cursor = display_end + + # 3. 创建映射函数 + def real_to_display(t_real: float) -> float: + """真实时间 -> 显示时间""" + for r_start, r_end, d_start, d_end, _ in segments: + if r_start <= t_real <= r_end: + # 线性插值 + ratio = (t_real - r_start) / (r_end - r_start) if r_end > r_start else 0 + return d_start + ratio * (d_end - d_start) + # 超出范围,返回最后的显示时间 + if segments: + return segments[-1][3] + return t_real + + def display_to_real(t_display: float) -> float: + """显示时间 -> 真实时间""" + for r_start, r_end, d_start, d_end, _ in segments: + if d_start <= t_display <= d_end: + ratio = (t_display - d_start) / (d_end - d_start) if d_end > d_start else 0 + return r_start + ratio * (r_end - r_start) + if segments: + return segments[-1][1] + return t_display + + return real_to_display, display_to_real, segments + + +# --- 4. [MODIFIED] 甘特图可视化函数 --- + +def visualize_case(infile_path: Path, outfile_path: Path, fig_path: Path): + """ + 为单个用例生成并保存甘特图可视化。 + [!!! MODIFIED !!!] 使用三个独立的变长时间轴。 + """ + + try: + in_data = parse_infile(infile_path) + out_data = parse_outfile(outfile_path) + except Exception as e: + print(f" [!] 失败: 解析文件时出错 {e}") + return + + L, M, N = in_data['L'], in_data['M'], in_data['N'] + requests = in_data['requests'] + + # 模拟 HBM 状态 + try: + hbm_state = np.zeros(L, dtype=np.int8) + except MemoryError: + print(f" [!] 失败: 无法为 L={L} 分配内存。跳过...") + return + + # 收集三类事件 + request_events = [] # (task_id, real_start, real_duration, label, color) + visit_events = [] + io_events = [] + + colors = list(mcolors.TABLEAU_COLORS.values()) + + # 收集 Requests (理论时间) + for task_id, req in requests.items(): + color = colors[task_id % len(colors)] + request_events.append(( + task_id, + req['start'], + req['time'], + f"R{task_id}", + color + )) + + # 遍历 Outfile 收集 Visits 和 IO + T_fin = 0 + for op in out_data: + T_start = op['T'] + T_fin = max(T_fin, T_start) + + if op['op'] == 'Reload': + A, S = op['A'], op['S'] + try: + chunk = hbm_state[A : A+S] + new_bytes = (chunk == 0) + actual_reload_size = int(np.sum(new_bytes)) + chunk[new_bytes] = 1 + except ValueError: + continue + + duration = 40 * actual_reload_size + if duration > 0: + io_events.append((T_start, duration, 'R', colors[0])) + T_fin = max(T_fin, T_start + duration) + + elif op['op'] == 'Offload': + A, S = op['A'], op['S'] + try: + chunk = hbm_state[A : A+S] + offload_bytes = (chunk == 1) + actual_offload_size = int(np.sum(offload_bytes)) + chunk[offload_bytes] = 0 + except ValueError: + continue + + duration = 40 * actual_offload_size + if duration > 0: + io_events.append((T_start, duration, 'O', colors[3])) + T_fin = max(T_fin, T_start + duration) + + elif op['op'] == 'Visit': + task_id = op['A'] + if task_id not in requests: + continue + + req = requests[task_id] + duration = req['time'] + color = colors[task_id % len(colors)] + + visit_events.append((task_id, T_start, duration, f"V{task_id}", color)) + T_fin = max(T_fin, T_start + duration) + + elif op['op'] == 'Fin': + T_fin = max(T_fin, T_start) + break + + # --- 5. 创建三个独立的变长时间轴 --- + + # Requests 时间轴 + request_intervals = [(start, start + dur) for _, start, dur, _, _ in request_events] + req_r2d, req_d2r, req_segments = create_variable_timescale(request_intervals) + + # Visits 时间轴 + visit_intervals = [(start, start + dur) for _, start, dur, _, _ in visit_events] + vis_r2d, vis_d2r, vis_segments = create_variable_timescale(visit_intervals) + + # IO 时间轴 + io_intervals = [(start, start + dur) for start, dur, _, _ in io_events] + io_r2d, io_d2r, io_segments = create_variable_timescale(io_intervals) + + # --- 6. 转换事件到显示坐标 --- + + request_bars = [] + for task_id, real_start, real_dur, label, color in request_events: + disp_start = req_r2d(real_start) + disp_end = req_r2d(real_start + real_dur) + disp_dur = disp_end - disp_start + request_bars.append((task_id, disp_start, disp_dur, label, color, real_start, real_dur)) + + visit_bars = [] + for task_id, real_start, real_dur, label, color in visit_events: + disp_start = vis_r2d(real_start) + disp_end = vis_r2d(real_start + real_dur) + disp_dur = disp_end - disp_start + visit_bars.append((task_id, disp_start, disp_dur, label, color, real_start, real_dur)) + + io_bars = [] + for real_start, real_dur, label, color in io_events: + disp_start = io_r2d(real_start) + disp_end = io_r2d(real_start + real_dur) + disp_dur = disp_end - disp_start + io_bars.append((disp_start, disp_dur, label, color, real_start, real_dur)) + + # --- 7. 绘图 --- + + height_ratio = [max(1, N), max(1, N), 1.5] + + fig, (ax_requests, ax_visits, ax_io) = plt.subplots( + 3, 1, + figsize=(24, 3 + N * 1.0 + 1.0), + gridspec_kw={'height_ratios': height_ratio} + ) + + # --- 子图 1: Requests --- + ax_requests.set_ylabel("Requests (Input)", fontsize=12, weight='bold') + + for (task_id, disp_start, disp_dur, label, color, real_start, real_dur) in request_bars: + ax_requests.barh( + y=task_id, + width=disp_dur, + left=disp_start, + height=0.8, + edgecolor='black', + alpha=0.4, + color=color, + hatch='//' + ) + # 标签显示真实时间 + ax_requests.text( + disp_start + disp_dur/2, + task_id, + f"{label}\n[{real_start}-{real_start+real_dur}]", + ha='center', + va='center', + color='black', + alpha=0.8, + fontsize=8 + ) + + ax_requests.set_yticks(range(N)) + ax_requests.set_yticklabels([f"R{i}" for i in range(N)]) + ax_requests.set_ylim(N - 0.5, -0.5) + ax_requests.grid(axis='x', linestyle=':', alpha=0.6) + ax_requests.set_title("Requests Timeline (Variable Scale)", fontsize=10, style='italic') + + # 添加压缩段标记 + for r_start, r_end, d_start, d_end, is_compressed in req_segments: + if is_compressed: + ax_requests.axvspan(d_start, d_end, alpha=0.1, color='gray', zorder=0) + + # --- 子图 2: Visits --- + ax_visits.set_ylabel("Visits (Scheduled)", fontsize=12, weight='bold') + + for (task_id, disp_start, disp_dur, label, color, real_start, real_dur) in visit_bars: + ax_visits.barh( + y=task_id, + width=disp_dur, + left=disp_start, + height=0.8, + edgecolor='black', + alpha=0.8, + color=color + ) + ax_visits.text( + disp_start + disp_dur/2, + task_id, + f"{label}\n[{real_start}-{real_start+real_dur}]", + ha='center', + va='center', + color='white', + weight='bold', + fontsize=8 + ) + + ax_visits.set_yticks(range(N)) + ax_visits.set_yticklabels([f"V{i}" for i in range(N)]) + ax_visits.set_ylim(N - 0.5, -0.5) + ax_visits.grid(axis='x', linestyle=':', alpha=0.6) + ax_visits.set_title("Visits Timeline (Variable Scale)", fontsize=10, style='italic') + + # 添加压缩段标记 + for r_start, r_end, d_start, d_end, is_compressed in vis_segments: + if is_compressed: + ax_visits.axvspan(d_start, d_end, alpha=0.1, color='gray', zorder=0) + + # --- 子图 3: IO --- + ax_io.set_ylabel("IO (Serial)", fontsize=12, weight='bold') + + for (disp_start, disp_dur, label, color, real_start, real_dur) in io_bars: + ax_io.barh( + y=0, + width=disp_dur, + left=disp_start, + height=0.8, + edgecolor='black', + alpha=0.75, + color=color + ) + # 标签 + ax_io.text( + disp_start + disp_dur/2, + 0, + f"{label}\n[{real_start}-{real_start+real_dur}]", + ha='center', + va='center', + color='white', + weight='bold', + fontsize=8 + ) + + ax_io.set_yticks([0]) + ax_io.set_yticklabels(["IO Lane"]) + ax_io.set_ylim(-0.5, 0.5) + ax_io.grid(axis='x', linestyle=':', alpha=0.6) + ax_io.set_xlabel("Display Time (Variable Scale)", fontsize=12, weight='bold') + ax_io.set_title("IO Timeline (Variable Scale)", fontsize=10, style='italic') + + # 添加压缩段标记 + for r_start, r_end, d_start, d_end, is_compressed in io_segments: + if is_compressed: + ax_io.axvspan(d_start, d_end, alpha=0.1, color='gray', zorder=0) + + # 设置 X 轴范围(每个子图独立) + if request_bars: + max_disp_req = max(disp_start + disp_dur for _, disp_start, disp_dur, _, _, _, _ in request_bars) + ax_requests.set_xlim(left=0, right=max_disp_req * 1.02) + + if visit_bars: + max_disp_vis = max(disp_start + disp_dur for _, disp_start, disp_dur, _, _, _, _ in visit_bars) + ax_visits.set_xlim(left=0, right=max_disp_vis * 1.02) + + if io_bars: + max_disp_io = max(disp_start + disp_dur for disp_start, disp_dur, _, _, _, _ in io_bars) + ax_io.set_xlim(left=0, right=max_disp_io * 1.02) + + # --- 8. 保存 --- + plt.tight_layout(rect=[0, 0.03, 1, 1.0]) + plt.savefig(fig_path, dpi=150, bbox_inches='tight') + plt.close(fig) + print(f" [✓] 成功: 已保存到 {fig_path.name}") + + +# --- 9. 主程序 --- + +def main(): + print(f"--- 开始可视化 (变长时间轴甘特图) ---") + print(f"输入目录: {IN_DIR}") + print(f"输出目录: {OUT_DIR}") + print(f"图形目录: {FIG_DIR}") + print("--------------------") + + FIG_DIR.mkdir(parents=True, exist_ok=True) + + infile_paths = sorted(glob.glob(str(IN_DIR / "infile_*.txt"))) + + if not infile_paths: + print(f"[!] 错误: 在 {IN_DIR} 中未找到 'infile_*.txt' 文件。") + return + + for infile_path in infile_paths: + infile_path = Path(infile_path) + + match = re.search(r'infile_(\w+)\.txt', infile_path.name) + if not match: + print(f"[!] 跳过: 无法解析文件名 {infile_path.name}") + continue + + file_id = match.group(1) + outfile_name = f"outfile_{file_id}.txt" + outfile_path = OUT_DIR / outfile_name + + fig_name = f"fig_{file_id}_gantt_variable.png" + fig_path = FIG_DIR / fig_name + + print(f"[*] 正在处理: {infile_path.name} -> {outfile_name}") + + if not outfile_path.exists(): + print(f" [!] 跳过: 未找到对应的 {outfile_name}") + continue + + try: + visualize_case(infile_path, outfile_path, fig_path) + except Exception as e: + print(f" [!] 失败: 处理 {infile_path.name} 时发生意外错误: {e}") + import traceback + traceback.print_exc() + + print("--------------------") + print("--- 可视化完成 ---") + +if __name__ == "__main__": + main() diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/src/visual_input.py" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/visual_input.py" new file mode 100644 index 0000000000000000000000000000000000000000..a2320044fe10adbd12531cb83f29a0604f884271 --- /dev/null +++ "b/2025/work/10 \346\234\261\346\231\250\346\233\246/src/visual_input.py" @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +""" +Input File Gantt Chart Visualization +Visualizes memory access patterns and I/O requirements from input files +""" + +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.patches import Rectangle +import numpy as np +from pathlib import Path + +def parse_input_file(filepath): + """Parse input file and extract request information""" + with open(filepath, 'r') as f: + first_line = f.readline().split() + L = int(first_line[0]) # Virtual address space + M = int(first_line[1]) # HBM capacity + N = int(first_line[2]) # Number of requests + + requests = [] + for i in range(N): + line = f.readline().split() + requests.append({ + 'id': i, + 'addr': int(line[0]), + 'size': int(line[1]), + 'start': int(line[2]), + 'duration': int(line[3]) + }) + + return L, M, N, requests + + +def visualize_input_gantt(filepath, output_path): + """ + Create a Gantt chart showing: + 1. Request computation time + 2. Required data loading time (theoretical) + 3. Memory access patterns + """ + + # Parse input + L, M, N, requests = parse_input_file(filepath) + + # Calculate metrics + IO_CYCLES_PER_BYTE = 40 + + # Create figure with 3 subplots (increased spacing) + fig, (ax_compute, ax_load, ax_memory) = plt.subplots( + 3, 1, + figsize=(16, 12), + gridspec_kw={'height_ratios': [2, 2, 1.5], 'hspace': 0.4} + ) + + # Color scheme + colors = plt.cm.Set3(np.linspace(0, 1, N)) + + # ========================================== + # Subplot 1: Computation Time (NPU) + # ========================================== + ax_compute.set_title( + f'Visit Time\n' + f'L: {L:,} | M: {M:,} | N: {N}', + fontsize=14, weight='bold', pad=20 + ) + ax_compute.set_ylabel('Request', fontsize=12, weight='bold') + ax_compute.set_xlabel('Time', fontsize=12, weight='bold') + + for req in requests: + ax_compute.barh( + y=req['id'], + width=req['duration'], + left=req['start'], + height=0.6, + color=colors[req['id']], + edgecolor='black', + linewidth=1.5, + alpha=0.8, + label=f"R{req['id']}" + ) + + ax_compute.set_yticks(range(N)) + ax_compute.set_yticklabels([f'R{i}' for i in range(N)]) + ax_compute.invert_yaxis() + ax_compute.grid(axis='x', linestyle=':', alpha=0.6) + ax_compute.set_xlim(left=0, right=max(r['start'] + r['duration'] for r in requests) * 1.1) + + # ========================================== + # Subplot 2: Required I/O Time (Theoretical) + # ========================================== + ax_load.set_title( + 'I/O Time', + fontsize=14, weight='bold', pad=15 + ) + ax_load.set_ylabel('Request', fontsize=12, weight='bold') + ax_load.set_xlabel('Time', fontsize=12, weight='bold') + + for req in requests: + load_time = req['size'] * IO_CYCLES_PER_BYTE + + ax_load.barh( + y=req['id'], + width=load_time, + left=req['start'], + height=0.6, + color=colors[req['id']], + edgecolor='darkred', + linewidth=2, + alpha=0.7, + hatch='//' + ) + + ax_load.set_yticks(range(N)) + ax_load.set_yticklabels([f'R{i}' for i in range(N)]) + ax_load.invert_yaxis() + ax_load.grid(axis='x', linestyle=':', alpha=0.6) + + # Calculate max time considering I/O + max_io_time = max(r['start'] + r['size'] * IO_CYCLES_PER_BYTE for r in requests) + ax_load.set_xlim(left=0, right=max_io_time * 1.1) + + # ========================================== + # Subplot 3: Memory Address Space Usage + # ========================================== + ax_memory.set_title( + f'L: {L:,} | M: {M:,} | N: {N}', + fontsize=14, weight='bold', pad=15 + ) + ax_memory.set_ylabel('Request', fontsize=12, weight='bold') + ax_memory.set_xlabel('Memory Address', fontsize=12, weight='bold') + + for req in requests: + ax_memory.barh( + y=req['id'], + width=req['size'], + left=req['addr'], + height=0.6, + color=colors[req['id']], + edgecolor='black', + linewidth=1.5, + alpha=0.6 + ) + + # Add HBM capacity line + ax_memory.axvline(x=M, color='red', linestyle='--', linewidth=2, label=f'HBM Capacity ({M:,})') + ax_memory.axvline(x=L, color='blue', linestyle='--', linewidth=2, label=f'Address Space ({L:,})') + + ax_memory.set_yticks(range(N)) + ax_memory.set_yticklabels([f'R{i}' for i in range(N)]) + ax_memory.invert_yaxis() + ax_memory.grid(axis='x', linestyle=':', alpha=0.6) + ax_memory.set_xlim(left=0, right=L * 1.05) + ax_memory.legend(loc='upper right', fontsize=10) + + # ========================================== + # Summary Statistics (Text Box) + # ========================================== + + # Calculate statistics + total_compute_time = sum(r['duration'] for r in requests) + total_data_size = sum(r['size'] for r in requests) + unique_bytes = set() + for req in requests: + for addr in range(req['addr'], req['addr'] + req['size']): + unique_bytes.add(addr) + unique_data_size = len(unique_bytes) + theoretical_io_time = unique_data_size * IO_CYCLES_PER_BYTE + + io_density = unique_data_size / total_compute_time if total_compute_time > 0 else 0 + memory_pressure = unique_data_size / M + time_pressure = theoretical_io_time / total_compute_time if total_compute_time > 0 else 0 + + # stats_text = ( + # f"═══ Workload Statistics ═══\n" + # f"Total Compute Time: {total_compute_time:,} cycles\n" + # f"Total Data Requested: {total_data_size:,} bytes\n" + # f"Unique Data: {unique_data_size:,} bytes\n" + # f"Theoretical I/O Time: {theoretical_io_time:,} cycles\n" + # f"\n" + # f"I/O Density Ratio: {io_density / 0.025:.2f}x\n" + # f"Memory Pressure: {memory_pressure:.2f}x (HBM)\n" + # f"Time Pressure: {time_pressure:.1%}\n" + # f"\n" + # f"Workload Type: " + # f"{'IO-Intensive' if io_density / 0.025 > 1.05 else 'Compute-Intensive'} + " + # f"{'Memory-Intensive' if memory_pressure >= 0.95 else 'Memory-Sufficient'}" + # ) + + # fig.text( + # 0.02, 0.02, stats_text, + # fontsize=10, + # family='monospace', + # bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5), + # verticalalignment='bottom' + # ) + + # ========================================== + # Layout and Save + # ========================================== + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches='tight') + plt.close() + + print(f"✓ Gantt chart saved to: {output_path}") + + + +def main(): + """Main function to process input files""" + + # Default paths + input_file = Path("/Users/chenxi/Documents/比赛/CCF/checker/example/infile/infile_32.txt") + output_file = Path("/Users/chenxi/Documents/比赛/CCF/checker/example/fig/gantt_input_32.png") + + # Ensure output directory exists + output_file.parent.mkdir(parents=True, exist_ok=True) + + print("=" * 60) + print("Input File Gantt Chart Visualization") + print("=" * 60) + print(f"Input: {input_file}") + print(f"Output: {output_file}") + print("-" * 60) + + if not input_file.exists(): + print(f"Error: Input file not found: {input_file}") + return + + visualize_input_gantt(input_file, output_file) + + print("=" * 60) + print("Visualization complete!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git "a/2025/work/10 \346\234\261\346\231\250\346\233\246/\350\207\252\351\200\202\345\272\224\347\255\276\345\220\215\351\251\261\345\212\250\347\232\204\350\231\232\346\213\237\346\256\265\350\260\203\345\272\246\347\256\227\346\263\225.pdf" "b/2025/work/10 \346\234\261\346\231\250\346\233\246/\350\207\252\351\200\202\345\272\224\347\255\276\345\220\215\351\251\261\345\212\250\347\232\204\350\231\232\346\213\237\346\256\265\350\260\203\345\272\246\347\256\227\346\263\225.pdf" new file mode 100644 index 0000000000000000000000000000000000000000..47be9251414faeea05b291d439dd21192d23a1b7 Binary files /dev/null and "b/2025/work/10 \346\234\261\346\231\250\346\233\246/\350\207\252\351\200\202\345\272\224\347\255\276\345\220\215\351\251\261\345\212\250\347\232\204\350\231\232\346\213\237\346\256\265\350\260\203\345\272\246\347\256\227\346\263\225.pdf" differ