[BUG] Failing to call CVRPEnv.local_search #212

Closed
2 tasks done
ShuN6211 opened this issue Sep 4, 2024 · 4 comments
Labels
bug Something isn't working


@ShuN6211
Contributor

ShuN6211 commented Sep 4, 2024

Describe the bug

Calling CVRPEnv.local_search fails with a TypeError raised by PyVRP's ProblemData constructor.

To Reproduce

import torch
from rl4co.models.zoo import AttentionModel
from rl4co.envs import CVRPEnv

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trained_model = AttentionModel.load_from_checkpoint("path/to/checkpoint")
env: CVRPEnv = trained_model.env
trained_policy = trained_model.policy.to(device)
td_init = env.reset(batch_size=[1]).to(device)

out_trained = trained_policy(
    td_init.clone(), phase="test", decode_type="greedy", return_actions=True
)

improved_actions = env.local_search(td_init, out_trained["actions"])  # raises TypeError
Traceback (most recent call last):
TypeError: __init__(): incompatible constructor arguments. The following argument types are supported:
    1. pyvrp._pyvrp.ProblemData(clients: List[pyvrp._pyvrp.Client], depots: List[pyvrp._pyvrp.Depot], vehicle_types: List[pyvrp._pyvrp.VehicleType], distance_matrices: List[numpy.ndarray[int]], duration_matrices: List[numpy.ndarray[int]], groups: List[pyvrp._pyvrp.ClientGroup] = [])

Invoked with: kwargs:
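The error lists the only ProblemData signature PyVRP accepts here: distances and durations go in as lists of matrices (distance_matrices, duration_matrices), next to clients, depots and vehicle_types. The "Invoked with" part is truncated, so what rl4co actually passed is unknown. Purely to illustrate the kind of keyword mismatch this message reports, the sketch below assumes a caller used single-matrix keywords (distance_matrix, duration_matrix, both hypothetical) and rewrites them into the list-valued form from the signature. The helper adapt_problem_data_kwargs is our own name, not part of rl4co or PyVRP.

```python
from typing import Any, Dict

# Keyword names copied from the ProblemData signature in the error message above.
SUPPORTED = {"clients", "depots", "vehicle_types",
             "distance_matrices", "duration_matrices", "groups"}

# Hypothetical singular keywords that code written for another PyVRP version might use.
SINGULAR_TO_PLURAL = {"distance_matrix": "distance_matrices",
                      "duration_matrix": "duration_matrices"}


def adapt_problem_data_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Rewrite single-matrix keywords into the list-valued form and reject unknown ones."""
    out = dict(kwargs)  # copy so the caller's dict is not mutated
    for singular, plural in SINGULAR_TO_PLURAL.items():
        if singular in out:
            if plural in out:
                raise TypeError(f"got both {singular!r} and {plural!r}")
            # A single matrix becomes a one-element list of matrices.
            out[plural] = [out.pop(singular)]
    unknown = sorted(set(out) - SUPPORTED)
    if unknown:
        raise TypeError(f"unsupported keyword(s): {unknown}")
    return out
```

With real data the resulting dict would be unpacked into pyvrp.ProblemData(**kwargs); which PyVRP release expects which form is not established by this thread.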

System info

INSTALLED VERSIONS
-------------------------------------
            rl4co : 0.5.0.dev2
            torch : 2.4.0
        lightning : 2.4.0
          torchrl : 0.5.0
       tensordict : 0.5.0
            numpy : 1.26.4
pytorch_geometric : 2.5.3
       hydra-core : 1.3.2
        omegaconf : 2.3.0
       matplotlib : 3.9.2
           Python : 3.11.9
         Platform : macOS-14.4.1-arm64-arm-64bit
 Lightning device : mps
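The version table above does not list pyvrp, which is the package whose ProblemData constructor raises the error. For dependency-related bugs like this one it can help to report it too. A minimal standard-library sketch (installed_versions is a hypothetical helper, not an rl4co utility; "pyvrp" is assumed to be the PyPI distribution name):

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if it is missing."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(["rl4co", "torch", "pyvrp"]).items():
        print(f"{name:>10} : {ver or 'not installed'}")
```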

Reason and Possible fixes

#211

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have provided a minimal working example to reproduce the bug (required)
@fedebotu
Member

fedebotu commented Sep 4, 2024

Hi @ShuN6211!

Thanks a lot, I pushed some hotfixes :)

# Jupyter only (not valid in a plain Python script):
# %load_ext autoreload
# %autoreload 2

import torch
from rl4co.models.zoo import AttentionModelPolicy
from rl4co.envs import CVRPEnv

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize environment and policy
env = CVRPEnv(generator_params={'num_loc': 50})
policy = AttentionModelPolicy(env_name=env.name).to(device)
td_init = env.reset(batch_size=[8]).to(device)

# Rollout policy (untrained here, you may load a pre-trained model first)
out = policy(
    td_init.clone(), phase="test", decode_type="greedy", return_actions=True
)

# Get initial rewards
rewards = env.get_reward(td_init, out["actions"])
print(f"Initial rewards: {rewards.mean():.3f}")

# Improve actions using local search
improved_actions = env.local_search(td_init.cpu(), out["actions"].cpu())  # move tensors to CPU for local search
rewards = env.get_reward(td_init.cpu(), improved_actions)
print(f"Improved rewards: {rewards.mean():.3f}")

The above should work!
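In the snippet above, env.get_reward scores a solution and local search is expected to raise that score. As far as we know, rl4co's CVRP reward is the negative total tour length, with node 0 as the depot. The pure-Python sketch below illustrates that idea without rl4co or PyVRP. cvrp_reward and two_opt are hypothetical helpers, and the toy 2-opt works on a single route and ignores vehicle capacity, unlike the PyVRP-based search that CVRPEnv.local_search relies on.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def cvrp_reward(depot: Point, locs: Sequence[Point], actions: Sequence[int]) -> float:
    """Negative length of depot -> actions... -> depot.

    Assumes index 0 is the depot and indices 1..n refer to locs.
    """
    coords: List[Point] = [depot, *locs]
    tour = [0, *actions, 0]
    length = sum(math.dist(coords[a], coords[b]) for a, b in zip(tour, tour[1:]))
    return -length


def two_opt(depot: Point, locs: Sequence[Point], actions: Sequence[int]) -> List[int]:
    """Toy 2-opt on a single route: reverse segments while the reward strictly improves."""
    best = list(actions)
    best_reward = cvrp_reward(depot, locs, best)
    improved = True
    while improved:
        improved = False
        for i in range(len(best) - 1):
            for j in range(i + 1, len(best)):
                # Reverse best[i..j], the classic 2-opt move.
                cand = best[:i] + best[i:j + 1][::-1] + best[j + 1:]
                r = cvrp_reward(depot, locs, cand)
                if r > best_reward + 1e-12:
                    best, best_reward, improved = cand, r, True
    return best


if __name__ == "__main__":
    depot, locs = (0.0, 0.0), [(1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
    start = [1, 3, 2]  # a crossing tour around the unit square
    better = two_opt(depot, locs, start)
    print(cvrp_reward(depot, locs, start), cvrp_reward(depot, locs, better))
```

Comparing the reward before and after the search mirrors the two print statements in the snippet above: a correct local search should never make the reward worse.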

@ShuN6211
Contributor Author

ShuN6211 commented Sep 4, 2024

@fedebotu Thanks for fixing not only local_search for CVRP but also the other modules and the version constraint!! I think this issue and #211 can be closed now.

@fedebotu
Member

fedebotu commented Sep 4, 2024

Yup! I only noticed later that you had also opened a PR; I went straight for the fix 🤣 Thanks~

@ShuN6211
Contributor Author

ShuN6211 commented Sep 4, 2024

Not a problem 🤣 Your fix was so quick, I'm really impressed!! Thanks!!

ShuN6211 closed this as completed Sep 4, 2024