Skip to content

Latest commit

 

History

History
245 lines (191 loc) · 4.76 KB

0065.买瓜.md

File metadata and controls

245 lines (191 loc) · 4.76 KB

65. 买瓜

题目链接

C

C++

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>

// Shorthand so that pair members can be written as `p.x` / `p.y`.
#define x first
#define y second

using namespace std;

using LL = long long;
using PII = pair<int, int>;  // used below as (doubled weight sum, cut count)

const int N = 35;            // capacity of the weight array

int n;                       // number of melons, read in main
int m;                       // target weight; main stores it doubled
int tn;                      // split index: dfs_1 handles [0, tn), dfs_2 handles [tn, n)
int t;                       // number of entries currently held in alls
int w[N];                    // melon weights; main stores them doubled
int res = 50;                // fewest cuts found so far; 50 means "nothing found"
PII alls[1 << 21];           // sums recorded by dfs_1

void dfs_1(int u, LL s, int cnt)
{
    if (s > m)
        return;
    if (u == tn)
    {
        alls[t ++] = {s, cnt};
        return;
    }
    
    dfs_1(u + 1, s, cnt);
    dfs_1(u + 1, s + w[u], cnt);
    dfs_1(u + 1, s + w[u] / 2, cnt + 1);
}

void dfs_2(int u, LL s, int cnt)
{
    if (s > m || cnt >= res)
        return;
    if (u == n)
    {   
        int l = 0, r = t - 1;
        while (l < r)
        {
            int mid = l + r >> 1;
            if (alls[mid].x + s >= m)
                r = mid;
            else
                l = mid + 1;
        }
        
        if (alls[l].x + s == m)
            res = min(res, alls[l].y + cnt);
        
        return;
    }
    
    dfs_2(u + 1, s, cnt);
    dfs_2(u + 1, s + w[u], cnt);
    dfs_2(u + 1, s + w[u] / 2, cnt + 1);
}

// Reads n, the target weight and the n weights, doubling all of them so that
// half a melon is still an integer. dfs_1 then records the reachable sums of
// melons [0, tn), and dfs_2 pairs them with melons [tn, n). Prints the fewest
// cuts that reach the target exactly, or -1 when no combination does.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> m;
    m *= 2;
    for (int i = 0; i < n; ++ i )
    {
        cin >> w[i];
        w[i] *= 2;
    }

    sort(w, w + n);

    // The first search gets slightly fewer than half of the melons.
    tn = n / 2 - 2;
    if (tn < 0)
        tn = 0;

    dfs_1(0, 0, 0);

    // Pairs sort by (sum, cuts), so the first entry of every equal-sum run
    // carries the fewest cuts; keep only that one.
    sort(alls, alls + t);
    int kept = 1;
    for (int i = 1; i < t; ++ i )
    {
        if (alls[i].x == alls[i - 1].x)
            continue;
        alls[kept] = alls[i];
        ++ kept;
    }
    t = kept;

    dfs_2(tn, 0, 0);

    // 50 is the untouched sentinel: no combination hit the target.
    if (res == 50)
        cout << -1 << endl;
    else
        cout << res << endl;

    return 0;
}

Java

import java.io.*;
import java.util.*;

/**
 * Finds the fewest cuts needed to buy melons weighing exactly the target.
 *
 * <p>Each melon may be skipped, taken whole, or cut once with one half taken.
 * All weights are doubled on input so that a half melon stays an integer.
 * The melons are split into two groups: {@code dfs_1} records every reachable
 * sum of the first group, and {@code dfs_2} enumerates the second group and
 * looks up the complementary sum by binary search.
 */
public class Main {
    /** Capacity of the weight array. */
    private static final int N = 35;
    /** Value held by {@code res} while no valid combination has been found. */
    private static final int NO_ANSWER = 50;

    // n: melon count; m: doubled target; tn: split index between the two
    // searches; t: number of entries stored in alls.
    private int n, m, tn, t;
    private int[] w = new int[N];
    private int res = NO_ANSWER;
    private Pair[] alls = new Pair[1 << 21];
    /** Tokens of the input line currently being consumed. */
    private StringTokenizer tokens;

    public static void main(String[] args) throws IOException {
        Main main = new Main();
        main.solve();
    }

    /**
     * Returns the next whitespace-separated integer from the input, reading
     * further lines as needed.
     *
     * @throws EOFException if the input ends before another token is found
     */
    private int nextInt(BufferedReader in) throws IOException {
        while (tokens == null || !tokens.hasMoreTokens()) {
            String line = in.readLine();
            if (line == null) {
                throw new EOFException("unexpected end of input");
            }
            tokens = new StringTokenizer(line);
        }
        return Integer.parseInt(tokens.nextToken());
    }

    /** Reads the input, runs both searches and prints the answer (or -1). */
    private void solve() throws IOException {
        BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
        PrintWriter out = new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));

        n = nextInt(in);
        m = nextInt(in) * 2;
        for (int i = 0; i < n; ++i) {
            w[i] = nextInt(in) * 2;
        }

        Arrays.sort(w, 0, n);

        // The first search gets slightly fewer than half of the melons.
        tn = Math.max(0, n / 2 - 2);

        dfs_1(0, 0, 0);

        // Order by sum, then by cut count. The secondary key is essential:
        // the sort is stable, so without it equal sums would stay in DFS
        // insertion order and the dedup below could keep an entry that needs
        // more cuts than necessary.
        Arrays.sort(alls, 0, t,
                Comparator.comparingInt(Pair::getX).thenComparingInt(Pair::getY));

        // Keep only the first (fewest-cuts) entry of every equal-sum run.
        int k = 1;
        for (int i = 1; i < t; ++i) {
            if (alls[i].x != alls[i - 1].x) {
                alls[k++] = alls[i];
            }
        }
        t = k;

        dfs_2(tn, 0, 0);

        out.println(res == NO_ANSWER ? -1 : res);
        out.flush();
    }

    /**
     * Enumerates skip / whole / half for melons {@code [u, tn)} and records
     * every completed choice whose doubled sum does not exceed {@code m}.
     *
     * @param u   index of the melon being decided
     * @param s   doubled weight accumulated so far
     * @param cnt cuts made so far
     */
    private void dfs_1(int u, long s, int cnt) {
        if (s > m)
            return;
        if (u == tn) {
            // s <= m here, so the cast to int loses nothing.
            alls[t++] = new Pair((int) s, cnt);
            return;
        }

        dfs_1(u + 1, s, cnt);
        dfs_1(u + 1, s + w[u], cnt);
        dfs_1(u + 1, s + w[u] / 2, cnt + 1);
    }

    /**
     * Enumerates skip / whole / half for melons {@code [u, n)} and, for each
     * completed choice, looks up the complementary sum in {@code alls}.
     * Lowers {@code res} when the combined cut count is smaller.
     *
     * @param u   index of the melon being decided
     * @param s   doubled weight accumulated so far
     * @param cnt cuts made so far
     */
    private void dfs_2(int u, long s, int cnt) {
        // Prune branches that overshoot or cannot beat the best answer.
        if (s > m || cnt >= res)
            return;
        if (u == n) {
            // Lower bound: first stored sum that brings the total to at least m.
            int l = 0, r = t - 1;
            while (l < r) {
                int mid = l + r >> 1;
                if (alls[mid].x + s >= m)
                    r = mid;
                else
                    l = mid + 1;
            }

            if (alls[l].x + s == m)
                res = Math.min(res, alls[l].y + cnt);

            return;
        }

        dfs_2(u + 1, s, cnt);
        dfs_2(u + 1, s + w[u], cnt);
        dfs_2(u + 1, s + w[u] / 2, cnt + 1);
    }

    /** A reachable doubled sum ({@code x}) and the cuts used to reach it ({@code y}). */
    static class Pair {
        int x, y;

        Pair(int x, int y) {
            this.x = x;
            this.y = y;
        }

        int getX() {
            return x;
        }

        int getY() {
            return y;
        }
    }
}

Python

# First input line: melon count and target weight. Second line: the weights.
length, target = (int(token) for token in input().split())
weights = [int(token) for token in input().split()]
# Fewest cuts found so far; 100 doubles as the "no solution yet" sentinel.
ans = 100

def dfs(i, cur_time, cur_weight, length, target, remaining=None):
    """Search the choices for melons i.. and record the fewest cuts in ``ans``.

    Each melon is taken whole, cut in half (one half taken, costing one cut),
    or skipped. The module-level ``ans`` is lowered whenever ``cur_weight``
    hits ``target`` with fewer cuts than any earlier hit.

    Args:
        i: index of the melon being decided.
        cur_time: number of cuts made so far.
        cur_weight: weight bought so far; becomes a float once a half is
            taken, because halves use true division.
        length: number of melons to consider.
        target: weight that must be reached exactly.
        remaining: sum of ``weights[i:]``. Callers may omit it; it is then
            computed once and passed down incrementally, instead of
            re-summing a slice at every node of the search.
    """
    global ans
    if remaining is None:
        remaining = sum(weights[i:])
    # Pruning: cannot beat the best answer found so far.
    if cur_time >= ans:
        return
    # Pruning: already heavier than the target.
    if cur_weight > target:
        return
    # Pruning: even taking every remaining melon whole falls short.
    if cur_weight + remaining < target:
        return
    # Target reached; cur_time < ans is guaranteed by the first check.
    if cur_weight == target:
        ans = cur_time
        # Any deeper call would be pruned by `cur_time >= ans`, so stop here.
        return
    if i >= length:
        return
    weight = weights[i]
    rest = remaining - weight
    # Take melon i whole.
    dfs(i + 1, cur_time, cur_weight + weight, length, target, rest)
    # Take half of melon i.
    dfs(i + 1, cur_time + 1, cur_weight + weight / 2, length, target, rest)
    # Skip melon i.
    dfs(i + 1, cur_time, cur_weight, length, target, rest)
    return

# Start at the first melon with nothing bought and no cuts made.
dfs(0, 0, 0, length, target)
# An untouched sentinel means no combination reached the target.
ans = -1 if ans == 100 else ans
print(ans)

JS

Go