前言
餐馆
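思路：可撤销的 0-1 背包。把物品分为"当前时刻生效"（s ≤ t ≤ e）和"其余"两组，各维护一个取模计数的 0-1 背包；物品在时刻 s 移入前一组、在时刻 e+1 移回后一组，"移出"就是 0-1 背包加入操作的逆运算。
考察了多个知识点，包括：
- 差分技巧（在 s 处加入、在 e+1 处移出）
- 离线思路（按时间对事件和查询排序后双指针扫描）
- 0-1 背包（及其撤销）
不过这题时间限制很紧，对语言的常数要求高（卡语言），尤其卡 Python。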
import java.io.*;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class Main {
    /**
     * Modulus for the subset-count DP. Counts are only used as a reachability signal
     * ({@code count != 0}); reducing them mod a prime keeps them bounded while keeping the
     * undo step ({@link #remove01}) an exact inverse of {@link #add01}.
     *
     * <p>NOTE(review): a sum that is reachable in a number of ways divisible by this modulus
     * looks unreachable. This is the usual hashing risk of the counting technique.
     */
    static final long mod = (long) 1e9 + 7;

    /**
     * Reads the instance from stdin and prints one "fz bz" line per query, in input order.
     *
     * <p>Input (whitespace-separated): {@code n v}, then n triples {@code s e w}, then
     * {@code q}, then q query times.
     */
    public static void main(String[] args) {
        AReader scanner = new AReader();
        PrintWriter out = new PrintWriter(new BufferedOutputStream(System.out));
        int n = scanner.nextInt();
        int v = scanner.nextInt();
        int[][] items = new int[n][];
        for (int i = 0; i < n; i++) {
            int s = scanner.nextInt();
            int e = scanner.nextInt();
            int w = scanner.nextInt();
            items[i] = new int[] {s, e, w};
        }
        int q = scanner.nextInt();
        int[] times = new int[q];
        for (int i = 0; i < q; i++) {
            times[i] = scanner.nextInt();
        }
        for (int[] pair : solve(v, items, times)) {
            out.println(pair[0] + " " + pair[1]);
        }
        out.flush();
        out.close();
    }

    /**
     * Answers every query offline.
     *
     * <p>For a query time t, item {s, e, w} is "active" iff {@code s <= t <= e}. The answer
     * is {fz, bz}: fz is the largest subset-sum of active weights that is {@code <= v}, and
     * bz is the largest subset-sum of the remaining (inactive) weights that is
     * {@code <= v - fz}.
     *
     * @param v     knapsack capacity
     * @param items rows {s, e, w}; weights are assumed non-negative and {@code s <= e}
     * @param times query times, in any order (they no longer need to lie in [0, 100])
     * @return one {fz, bz} row per query, in the order of {@code times}
     */
    static int[][] solve(int v, int[][] items, int[] times) {
        // Difference-style events: an item enters the active set at s and leaves at e + 1.
        List<int[]> ops = new ArrayList<>(2 * items.length);
        for (int[] item : items) {
            ops.add(new int[] {item[0], 1, item[2]});
            ops.add(new int[] {item[1] + 1, -1, item[2]});
        }
        ops.sort(Comparator.comparingInt(a -> a[0]));

        // Queries sorted by time (stable sort), each remembering its original index.
        List<int[]> qs = new ArrayList<>(times.length);
        for (int i = 0; i < times.length; i++) {
            qs.add(new int[] {times[i], i});
        }
        qs.sort(Comparator.comparingInt(a -> a[0]));

        // Two disjoint item sets: dp1 counts subsets of active items per sum, dp2 inactive ones.
        long[] dp1 = new long[v + 1];
        long[] dp2 = new long[v + 1];
        dp1[0] = 1;
        dp2[0] = 1;
        for (int[] item : items) {
            add01(dp2, item[2], v); // every item starts out inactive
        }

        int[][] res = new int[times.length][2];
        int p1 = 0;
        int p2 = 0;
        // Sweep over the distinct query times instead of a hard-coded time range.
        while (p2 < qs.size()) {
            int t = qs.get(p2)[0];
            while (p1 < ops.size() && ops.get(p1)[0] <= t) {
                int[] op = ops.get(p1++);
                int w = op[2];
                if (op[1] == 1) {
                    add01(dp1, w, v);
                    remove01(dp2, w, v);
                } else {
                    remove01(dp1, w, v);
                    add01(dp2, w, v);
                }
            }
            int fz = highestReachable(dp1, v);
            int bz = highestReachable(dp2, v - fz);
            // One O(v) scan serves every query that shares this time.
            while (p2 < qs.size() && qs.get(p2)[0] == t) {
                int[] query = qs.get(p2++);
                res[query[1]][0] = fz;
                res[query[1]][1] = bz;
            }
        }
        return res;
    }

    /** Returns the largest j in [0, limit] with a non-zero count dp[j], or 0 if none. */
    private static int highestReachable(long[] dp, int limit) {
        for (int j = limit; j >= 0; j--) {
            if (dp[j] != 0) {
                return j;
            }
        }
        return 0;
    }

    /**
     * Adds one item of weight w to the subset-count table dp[0..v] in place (multiplication
     * by 1 + x^w, truncated at degree v). Assumes {@code w >= 0}; {@code w > v} is a no-op.
     */
    private static void add01(long[] dp, int w, int v) {
        if (w == 0) {
            // Zero-weight items never change reachability; skipping them keeps add01 and
            // remove01 exact inverses (the in-place undo cannot handle w == 0).
            return;
        }
        for (int u = v - w; u >= 0; u--) {
            dp[u + w] = (dp[u + w] + dp[u]) % mod;
        }
    }

    /**
     * Undoes {@link #add01} for one previously added item of weight w (division by 1 + x^w).
     * Values stay normalized to [0, mod).
     */
    private static void remove01(long[] dp, int w, int v) {
        if (w == 0) {
            return; // mirrors add01: zero-weight items are never recorded
        }
        // Ascending order: dp[u] already excludes the item by the time it is subtracted.
        for (int u = 0; u <= v - w; u++) {
            dp[u + w] = (dp[u + w] - dp[u] + mod) % mod;
        }
    }

    /** Minimal whitespace-token reader over stdin. */
    static class AReader {
        private final BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
        private StringTokenizer tokenizer = new StringTokenizer("");

        /** Reads one raw line; returns null at end of input. */
        private String innerNextLine() {
            try {
                return reader.readLine();
            } catch (IOException ex) {
                // Surface I/O failures instead of silently treating them as end of input.
                throw new UncheckedIOException(ex);
            }
        }

        /** Returns true if another token is available, reading further lines as needed. */
        public boolean hasNext() {
            while (!tokenizer.hasMoreTokens()) {
                String nextLine = innerNextLine();
                if (nextLine == null) {
                    return false;
                }
                tokenizer = new StringTokenizer(nextLine);
            }
            return true;
        }

        /** Discards any buffered tokens and returns the next raw line (null at EOF). */
        public String nextLine() {
            tokenizer = new StringTokenizer("");
            return innerNextLine();
        }

        /** Returns the next token; throws NoSuchElementException at end of input. */
        public String next() {
            hasNext();
            return tokenizer.nextToken();
        }

        public int nextInt() {
            return Integer.parseInt(next());
        }

        public long nextLong() {
            return Long.parseLong(next());
        }
    }
}
Python 版本（常数较大，会被卡常导致超时）：
# coding=utf-8
"""Restaurant problem, Python version: reversible (undoable) 0-1 knapsack, answered offline.

For a query time t an item (s, e, w) is "active" iff s <= t <= e. Each answer is the pair
(fz, bz): fz is the largest subset-sum of active weights that is <= v, and bz is the largest
subset-sum of the remaining (inactive) weights that is <= v - fz.
"""
import sys

# Subset counts are only used as a reachability signal (count != 0). Reducing them mod a
# prime keeps them machine-sized (unreduced counts grow like 2**n, which makes the big-int
# arithmetic slow) while keeping remove01 an exact inverse of add01.
# NOTE(review): a sum reachable in a multiple-of-MOD number of ways would look unreachable.
MOD = 10**9 + 7


def add01(dp, w, v):
    """Add one item of weight w to the subset-count table dp[0..v], in place."""
    if w == 0:
        # Zero-weight items never change which sums are reachable; skipping them keeps
        # add01/remove01 exact inverses (the in-place undo cannot handle w == 0).
        return
    for u in range(v - w, -1, -1):
        dp[u + w] = (dp[u + w] + dp[u]) % MOD


def remove01(dp, w, v):
    """Undo add01(dp, w, v): remove one previously added item of weight w, in place."""
    if w == 0:
        return
    # Ascending order: dp[u] already excludes the item by the time it is subtracted.
    for u in range(0, v - w + 1):
        dp[u + w] = (dp[u + w] - dp[u]) % MOD


def _highest_reachable(dp, limit):
    """Return the largest j in [0, limit] with dp[j] != 0, or 0 if there is none."""
    for j in range(limit, -1, -1):
        if dp[j]:
            return j
    return 0


def solve(v, items, times):
    """Answer all queries offline.

    Args:
        v: knapsack capacity.
        items: iterable of (s, e, w) with non-negative w and s <= e (assumed).
        times: query times, in any order.

    Returns:
        A list of (fz, bz) tuples, one per query, in the order of ``times``.
    """
    # Difference-style events: an item becomes active at s and inactive again at e + 1.
    ops = []
    for s, e, w in items:
        ops.append((s, 1, w))
        ops.append((e + 1, -1, w))
    ops.sort(key=lambda op: op[0])
    # Sort query indices by time (the original key `lambda x: [0]` was a constant and
    # therefore left the queries unsorted, breaking the two-pointer sweep).
    order = sorted(range(len(times)), key=lambda i: times[i])

    dp1 = [0] * (v + 1)  # subset counts of active items
    dp2 = [0] * (v + 1)  # subset counts of inactive items
    dp1[0] = 1
    dp2[0] = 1
    for _, _, w in items:
        add01(dp2, w, v)  # every item starts out inactive

    res = [None] * len(times)
    p1 = 0
    k = 0
    # Sweep only over the distinct query times.
    while k < len(order):
        t = times[order[k]]
        while p1 < len(ops) and ops[p1][0] <= t:
            _, d, w = ops[p1]
            if d == 1:
                add01(dp1, w, v)
                remove01(dp2, w, v)
            else:
                remove01(dp1, w, v)
                add01(dp2, w, v)
            p1 += 1
        fz = _highest_reachable(dp1, v)
        bz = _highest_reachable(dp2, v - fz)
        while k < len(order) and times[order[k]] == t:
            res[order[k]] = (fz, bz)
            k += 1
    return res


def main():
    """Read `n v`, n triples `s e w`, `q` and q query times from stdin; print the answers."""
    it = iter(sys.stdin.buffer.read().split())
    n, v = int(next(it)), int(next(it))
    items = [(int(next(it)), int(next(it)), int(next(it))) for _ in range(n)]
    q = int(next(it))
    times = [int(next(it)) for _ in range(q)]
    res = solve(v, items, times)
    print("\n".join("%d %d" % pair for pair in res))


if __name__ == "__main__":
    main()