...
 
Commits (3)
import gym
import gym_shopping_cart
from gym_shopping_cart.data import InstacartData
import numpy as np
env = gym.make("ShoppingCart-v0")
......@@ -8,7 +7,7 @@ episode_over = False
chicken_id = 6046
print(env.data.columns()[chicken_id])
print(env.data.product_str(chicken_id))
action = np.zeros(InstacartData.N_PRODUCTS)
action = np.zeros(env.data.n_products())
action[chicken_id] = 1
rewards = 0
while not episode_over:
......
......@@ -2,6 +2,5 @@ from gym.envs.registration import register
register(id="ShoppingCart-v0", entry_point="gym_shopping_cart.envs:ShoppingCart")
register(
id="SimplifiedShoppingCart-v0",
entry_point="gym_shopping_cart.envs:SimplifiedShoppingCart",
id="SimpleShoppingCart-v0", entry_point="gym_shopping_cart.envs:SimpleShoppingCart"
)
......@@ -110,6 +110,7 @@ class InstacartData:
res = pd.concat(
[encoded_products, encoded_dow, encoded_hod, encoded_days_since], axis=1
)
res = res.fillna(0)
return res.sort_index()
def _raw_orders_for_user(self, id: np.uint32 = None) -> pd.DataFrame:
......
from gym_shopping_cart.envs.shopping_cart_v0 import ShoppingCart
from gym_shopping_cart.envs.shopping_cart_v0 import SimplifiedShoppingCart
from gym_shopping_cart.envs.shopping_cart_v0 import SimpleShoppingCart
......@@ -101,7 +101,7 @@ class ShoppingCart(gym.Env):
pass
class SimplifiedShoppingCart(ShoppingCart):
class SimpleShoppingCart(ShoppingCart):
"""
Exactly the same as ShoppingCart except I limit the number of products to the 20 most popular
"""
......
......@@ -30,7 +30,7 @@ def test_parse_instacart_data():
assert res.loc[33].shape[0] == InstacartData.N_OBSERVATIONS
assert res.loc[33]["order_dow_3"] == 1
assert res.loc[33]["order_hour_of_day_12"] == 1
assert np.isnan(res.loc[1]["days_since_prior_order"])
assert res.loc[1]["days_since_prior_order"] == 0 # NOT nan.
assert res.loc[2].shape[0] == InstacartData.N_OBSERVATIONS
assert res.loc[2]["order_dow_1"] == 1
assert res.loc[2]["product_id_9637"] == 1
......
......@@ -5,7 +5,7 @@ import gym_shopping_cart
def test_simplified_shopping_cart():
env = gym.make("SimplifiedShoppingCart-v0")
env = gym.make("SimpleShoppingCart-v0")
state, _, _, _ = env.step(env.action_space.sample())
assert isinstance(state, np.ndarray)
assert state.shape[0] == 52