8.2.7. Example: Recommendation using NCF
Sample program demonstrating training of the Neural Collaborative Filtering (NCF) model on the MovieLens 1M Dataset
Note
The trained model and some source code used in this example have been partially modified or directly sourced from mlcommons/training. All of these components are licensed under the Apache License, Version 2.0.
Execution Method
The first execution performs the following downloads (subsequent runs will skip these steps).
By default, the download location is /tmp/ncf_training/.
Python packages
$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/ncf_training
$ ./preparation.sh
$ ./ncf_training.sh --option_json ../../../preset_options/debug.json
Output
The evaluation results are output in the following format. They show the hit rate (HR) and the normalized discounted cumulative gain (NDCG).
HR@10 = 0.1628, NDCG@10 = 0.0808
Scripts
1import argparse
2import os
3from pathlib import Path
4
5import numpy as np
6import torch
7import torch.nn as nn
8from alias_generator import AliasSample
9from mlsdk import Context, MNCoreAdam, MNCoreOptimizer, MNDevice
10from ncf_eval import run_eval
11from ncf_utils import (
12 generate_neg_dataset,
13 generate_padding,
14 load_eval_data,
15 load_model,
16 load_train_pos_data,
17 save_model,
18)
19
20# import from externals/mlcommons-ncf/recommendation/pytorch
21from neumf import NeuMF
22from torch.utils.data import ConcatDataset, DataLoader
23from utility import apply_toml_defaults, compile_fn, set_deterministic_mode
24
25
def run_train(  # noqa: CFQ002
    args: argparse.Namespace,
    model: nn.Module,
    device: str,
    context: Context,
    optimizer: MNCoreOptimizer,
    loss_fn: nn.BCEWithLogitsLoss,
    outdir: str,
    pos_dataset: torch.utils.data.Dataset,
    neg_sampler: AliasSample,
    num_items: np.int64,
) -> None:
    """Compile the NCF training step and run it for ``args.epoch`` epochs.

    Negatives are re-sampled every epoch, concatenated with the positive
    dataset, and streamed through the compiled training function.

    Args:
        args: parsed options (batch size, epochs, negative ratio, ...).
        model: the NeuMF model to train (updated in place).
        device: device name string (unused here; kept for interface parity).
        context: MLSDK context the training function is compiled under.
        optimizer: on-device optimizer driving the parameter updates.
        loss_fn: element-wise BCE-with-logits loss (reduction="none").
        outdir: output directory for generated code.
        pos_dataset: TensorDataset of (user, item, label) positives.
        neg_sampler: alias sampler used for collision-free negatives.
        num_items: total number of items (bound for uniform negative sampling).
    """

    # Define training functions
    def train_fx2onnx(sample_d: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        optimizer.zero_grad()

        outputs = model(sample_d["user"], sample_d["item"])
        loss = loss_fn(outputs, sample_d["label"]).float()
        # Manual mean over the batch (loss_fn uses reduction="none").
        loss = torch.mean(loss.view(-1), 0)

        loss.backward()
        optimizer.step()
        return {"result": loss}

    # Dummy input for compilation
    samples = {
        "user": torch.randint(1, (args.train_batch_size,), dtype=torch.int64),
        "item": torch.randint(1, (args.train_batch_size,), dtype=torch.int64),
        "label": torch.rand(args.train_batch_size).view(-1, 1),
    }

    train_fn = compile_fn(
        context,
        train_fx2onnx,
        model,
        samples,
        outdir=outdir,
        model_name="ncf_train",
        is_train=True,
        optimizers=[optimizer],
        option_json=str(args.option_json),
    )

    model.train()

    total_iter = (
        len(pos_dataset) * (1 + args.train_neg_ratio)
    ) // args.train_batch_size

    for epoch in range(args.epoch):
        # Fresh negatives are drawn every epoch.
        pos_users, _, pos_labels = pos_dataset.tensors
        neg_dataset = generate_neg_dataset(
            pos_users,
            pos_labels.size(),
            args.train_neg_ratio,
            num_items,
            args.allow_collision_with_pos,
            neg_sampler,
        )
        dataloader = DataLoader(
            ConcatDataset([pos_dataset, neg_dataset]),
            batch_size=args.train_batch_size,
            shuffle=True,
            num_workers=args.loader_num_workers,
            drop_last=True,
        )

        for num_batch, (users, items, labels) in enumerate(dataloader):
            samples["user"] = users
            samples["item"] = items
            samples["label"] = labels.view(-1, 1)

            output = train_fn(samples)
            epoch_str = f"Epoch {epoch + 1}/{args.epoch}"
            iteration_str = f"Iteration {num_batch + 1}/{total_iter}"
            # Use single quotes for the dict key: reusing the f-string's own
            # quote character requires Python >= 3.12 (PEP 701).
            loss_str = f"Loss: {output['result'].item():.4}"
            print(f"{epoch_str}, {iteration_str}, {loss_str}")

        # Synchronize tensors on MN-Core 2's DRAM and PyTorch tensors
        context.synchronize()
106
107
def main(args: argparse.Namespace) -> None:  # noqa: CFQ001
    """Run the full NCF example: load data, train, checkpoint, then evaluate.

    Args:
        args: fully-populated namespace combining configs.toml defaults and
            command line overrides (see the ``__main__`` block below).
    """

    # Fail fast: verify the checkpoint destination is writable before
    # spending time on training.
    if args.save_path != "":
        dir_path = os.path.dirname(args.save_path) or "."
        if not os.access(dir_path, os.W_OK):
            raise ValueError("Parent directory of save_path is not writable.")

    # Fix seed values for reproducibility
    set_deterministic_mode(args.seed)

    # Decide device and outdir from given options
    device_name = args.device
    outdir = args.outdir

    # Load positive data for training and create dataset
    # (the scaled dataset is prepared under /tmp by preparation.sh)
    data_dir = f"/tmp/ncf_training/{args.dataset}"
    scaled_data_dir = (
        f"{data_dir}/{args.dataset}x{args.user_scaling}x{args.item_scaling}"
    )
    train_pos_dataset, num_users, num_items, neg_sampler = load_train_pos_data(
        scaled_data_dir, args.user_scaling, args.item_scaling
    )

    # Define model (regularization terms are disabled: all regs set to 0.0)
    model = NeuMF(
        num_users,
        num_items,
        mf_dim=args.factors,
        mf_reg=0.0,
        mlp_layer_sizes=args.mlp_layers,
        mlp_layer_regs=([0.0] * len(args.mlp_layers)),
    )

    # Create loss function object; reduction="none" because run_train
    # averages the per-element losses manually
    loss_fn = nn.BCEWithLogitsLoss(reduction="none")

    # Create optimizer object
    optimizer = MNCoreAdam(
        model.parameters(),
        lr=args.learning_rate,
        chainer_use_torch=True,
    )

    # Pass the device information to context or move the model and optimizer
    # to specified device. Separate contexts are created for training and
    # evaluation; the training context is made active first.
    device = MNDevice(device_name)
    train_context = Context(device)
    eval_context = Context(device)
    Context.switch_context(train_context)

    # Load pre-trained model (optional warm start)
    if args.load_path != "":
        load_model(model, optimizer, model_path=args.load_path)

    # Run training
    run_train(
        args,
        model,
        device,
        train_context,
        optimizer,
        loss_fn,
        outdir,
        train_pos_dataset,
        neg_sampler,
        num_items,
    )

    # Save trained model
    if args.save_path != "":
        save_model(model, optimizer, outdir, model_path=args.save_path)

    # Load positive and negative data for evaluation and create dataloader
    eval_dataset, samples_per_user = load_eval_data(
        scaled_data_dir,
        num_users,
        args.user_scaling,
        args.item_scaling,
        args.eval_neg_ratio,
    )
    # Batch whole users: each user contributes samples_per_user candidate rows
    users_per_eval_batch = max(args.eval_batch_size // samples_per_user, 1)
    # Pad with dummy users so the dataset divides evenly into eval batches
    eval_dataset = ConcatDataset(
        [
            eval_dataset,
            generate_padding(len(eval_dataset), users_per_eval_batch, samples_per_user),
        ]
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=users_per_eval_batch,
        shuffle=False,
        num_workers=args.loader_num_workers,
    )

    # Switch to evaluation context
    Context.switch_context(eval_context)

    # Run evaluation
    run_eval(
        args,
        model,
        device,
        eval_context,
        outdir,
        eval_dataloader,
        samples_per_user,
        num_users,
    )
215
216
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()

    # Options consumed by the MLSDK runtime/compiler.
    arg_parser.add_argument("--device", type=str, default="mncore2:auto")
    arg_parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_ncf_training/out")
    arg_parser.add_argument(
        "--option_json",
        type=Path,
        default="/opt/pfn/pfcomp/codegen/preset_options/O1.json",
    )

    # Remaining option defaults come from configs.toml next to this script.
    apply_toml_defaults(str(Path(__file__).parent / "configs.toml"), arg_parser)

    # Parse command line args and opts, then run the pipeline.
    main(arg_parser.parse_args())
1import argparse
2import math
3
4import torch
5from mlsdk import CompiledFunction, Context
6from utility import compile_fn
7
8
# Measure inference accuracy as hit ratio (HR) and normalized discounted cumulative gain (NDCG)
def measure_acc(  # noqa: CFQ002
    args: argparse.Namespace,
    device: str,
    dataloader: torch.utils.data.DataLoader,
    samples_per_user: int,
    infer_fn: CompiledFunction,
    K: int,
    num_user: int,
) -> None:
    """Score every user's candidate list with ``infer_fn`` and print HR@K / NDCG@K."""
    ln2 = math.log(2)
    total_hits = torch.tensor(0.0)
    total_ndcg = torch.tensor(0.0)

    with torch.no_grad():
        for user, item, dup_mask, pos_item_indices in dataloader:
            batch = {
                "user": user.view(-1),
                "item": item.view(-1),
            }

            # One score per (user, item) pair, regrouped per user.
            scores = infer_fn(batch)["result"].detach().view(-1, samples_per_user)

            # Duplicate items must never reach the top-k: force their score to -1.
            scores[dup_mask.bool()] = -1
            _, top_k_indices = torch.topk(scores, K)

            # A "hit" means the positive item made it into the top-k.
            hit_mask = top_k_indices == pos_item_indices.unsqueeze(1)
            total_hits += hit_mask.sum().item()

            # NDCG credits each hit by rank: log(2) / log(rank + 2).
            hit_ranks = torch.nonzero(hit_mask)[:, 1].view(-1).to(torch.float)
            total_ndcg += (ln2 / (hit_ranks + 2).log_()).sum()

    hit_rate = total_hits.item() / num_user
    ndcg = total_ndcg.item() / num_user
    print(f"HR@{K} = {hit_rate:.4f}, NDCG@{K} = {ndcg:.4f}")
48
49
def run_eval(  # noqa: CFQ002
    args: argparse.Namespace,
    model: torch.nn.Module,
    device: str,
    context: Context,
    outdir: str,
    dataloader: torch.utils.data.DataLoader,
    samples_per_user: int,
    num_user: int,
) -> None:
    """Compile the model for inference and report HR/NDCG over ``dataloader``."""

    # Inference wrapper traced for compilation.
    def infer_fx2onnx(sample_d: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        with torch.no_grad():
            # sigmoid=True: the model itself produces probabilities here.
            return {"result": model(sample_d["user"], sample_d["item"], sigmoid=True)}

    # The first evaluation batch serves as the sample input for compilation.
    first_user, first_item = next(iter(dataloader))[:2]
    sample = {
        "user": first_user.view(-1),
        "item": first_item.view(-1),
    }

    infer_fn = compile_fn(
        context,
        infer_fx2onnx,
        model,
        sample,
        outdir=outdir,
        model_name="ncf_eval",
        is_train=False,
        option_json=str(args.option_json),
    )

    model.eval()

    measure_acc(
        args, device, dataloader, samples_per_user, infer_fn, args.topk, num_user
    )
1import os
2import pickle
3
4import numpy as np
5import torch
6
7# import from externals/mlcommons-ncf/recommendation/pytorch
8from alias_generator import AliasSample
9from convert import CACHE_FN, generate_negatives
10from mlsdk import MNCoreOptimizer
11from torch.utils.data import TensorDataset
12
13
def load_sampler(
    data_dir: str, user_scaling: int, item_scaling: int
) -> tuple[AliasSample, np.ndarray, np.ndarray, np.int64]:
    """Load the cached negative-sampling alias table and positive pairs.

    Args:
        data_dir: directory holding the cache produced by convert.py.
        user_scaling / item_scaling: scaling factors baked into the cache name.

    Returns:
        (sampler, pos_users, pos_items, num_items) from the pickle cache.

    Raises:
        ValueError: if the sampler directory or the cached sampler file
            is missing.
    """
    fn_prefix = data_dir + "/" + CACHE_FN.format(user_scaling, item_scaling)
    sampler_cache = fn_prefix + "cached_sampler.pkl"

    if not os.path.exists(data_dir):
        raise ValueError(f"sampler directory does not exist: {data_dir}")
    # Check the cache file itself, not only the directory: previously a
    # present directory with a missing cache surfaced as a confusing
    # FileNotFoundError from open() instead of a clear error.
    if not os.path.exists(sampler_cache):
        raise ValueError(f"sampler cache file does not exist: {sampler_cache}")

    print(f"Using alias file: {sampler_cache}")
    # NOTE: pickle.load is acceptable here because the cache is generated
    # locally by preparation.sh/convert.py, not fetched from untrusted input.
    with open(sampler_cache, "rb") as f:
        sampler, pos_users, pos_items, num_items, _ = pickle.load(f)

    return (sampler, pos_users, pos_items, num_items)
28
29
def generate_neg_dataset(
    pos_users: torch.Tensor,
    label_size: torch.Size,
    neg_ratio: int,
    num_items: np.int64,
    allow_collision: bool,
    sampler: AliasSample,
) -> TensorDataset:
    """Build ``neg_ratio`` negative samples per positive interaction."""
    if allow_collision:
        # Fast path: draw item ids uniformly at random; a negative may
        # coincide with one of the user's positive items.
        neg_users = pos_users.repeat(neg_ratio)
        neg_items = torch.empty_like(neg_users, dtype=torch.int64).random_(
            0, int(num_items)
        )
    else:
        # Collision-free path: reuse the alias sampler built by convert.py,
        # which avoids item-id collisions with the positive data.
        sampled = torch.from_numpy(
            generate_negatives(sampler, neg_ratio, pos_users.numpy())
        )
        neg_users = sampled[:, 0]
        neg_items = sampled[:, 1]

    # All negatives carry label 0.0, replicated to match the sample count.
    neg_labels = torch.zeros(label_size, dtype=torch.float32).repeat(neg_ratio)
    return TensorDataset(neg_users, neg_items, neg_labels)
53
54
def load_train_pos_data(
    data_dir: str, user_scaling: int, item_scaling: int
) -> tuple[TensorDataset, int, np.int64, AliasSample]:
    """Load cached positive interactions as a TensorDataset of (user, item, 1.0)."""
    sampler, raw_users, raw_items, num_items = load_sampler(
        data_dir, user_scaling, item_scaling
    )

    # The alias table holds one region per user.
    num_users = len(sampler.num_regions)

    users = torch.from_numpy(raw_users).type(torch.LongTensor)
    items = torch.from_numpy(raw_items).type(torch.LongTensor)
    # Every observed interaction is a positive example (label 1.0).
    labels = torch.ones_like(users, dtype=torch.float32)

    return TensorDataset(users, items, labels), num_users, num_items, sampler
69
70
def load_eval_data(
    data_dir: str, num_users: int, user_scaling: int, item_scaling: int, neg_ratio: int
) -> tuple[TensorDataset, int]:
    """Build the per-user evaluation dataset for ranking metrics.

    Each user row pairs ``neg_ratio`` negative items with one positive test
    item, records where the positive sits inside that candidate list, and
    marks duplicate items so they can be excluded from top-k ranking.

    Returns:
        (dataset, samples_per_user): dataset rows are
        (users, items, dup_mask, pos_item_indices); samples_per_user is the
        candidate-list length, i.e. neg_ratio + 1.
    """
    # Load positive items (column 1 of each test chunk is the item id;
    # reshape keeps one positive per user row)
    pos_item_chunks = []
    for chunk_id in range(user_scaling):
        pos_ratings = torch.from_numpy(
            np.load(
                f"{data_dir}/testx{user_scaling}x{item_scaling}_{chunk_id}.npz",
                encoding="bytes",
            )["arr_0"]
        )
        pos_item_chunks.append(pos_ratings[:, 1].reshape(-1, 1))

    # Load negative items (neg_ratio candidates per user row)
    neg_item_chunks = []
    for chunk_id in range(user_scaling):
        neg_ratings = torch.from_numpy(
            np.load(
                f"{data_dir}/test_negx{user_scaling}x{item_scaling}_{chunk_id}.npz",
                encoding="bytes",
            )["arr_0"]
        )
        neg_item_chunks.append(neg_ratings[:, 1].reshape(-1, neg_ratio))

    # Concat positive and negative items (the positive ends up in the last column)
    item_chunks = [
        torch.cat((negs, poses), dim=1)
        for negs, poses in zip(neg_item_chunks, pos_item_chunks)
    ]

    # Get indices of positive items in concatenated items
    # (argmax returns the FIRST occurrence of the positive item id in the row)
    pos_item_index_chunks = []
    for items, pos_items in zip(item_chunks, pos_item_chunks):
        is_positive_mask = items == pos_items
        pos_item_index_chunks.append(torch.argmax(is_positive_mask.long(), dim=1))

    # Create a mask to identify duplicate items to avoid them during evaluation
    dup_mask_chunks = []
    for items in item_chunks:
        # Stable-sort each row so equal item ids become adjacent while the
        # indices remember where each element came from.
        stable_indices = torch.argsort(items, dim=1, stable=True)
        sorted_items = torch.gather(items, 1, stable_indices)

        # An element equal to its left neighbor is a duplicate occurrence;
        # the first occurrence (column 0) is never marked.
        is_duplicate_sorted = sorted_items[:, 1:] == sorted_items[:, :-1]
        dup_mask_sorted = torch.cat(
            [
                torch.zeros(is_duplicate_sorted.shape[0], 1, dtype=torch.bool),
                is_duplicate_sorted,
            ],
            dim=1,
        )

        # Unsort the mask back to the original item order
        inverse_indices = torch.argsort(stable_indices, dim=1)
        dup_mask = torch.gather(dup_mask_sorted, 1, inverse_indices)
        dup_mask_chunks.append(dup_mask)

    # Concatenate all chunks into final Tensors
    items = torch.cat(item_chunks, dim=0).long()
    dup_mask = torch.cat(dup_mask_chunks, dim=0)
    pos_item_indices = torch.cat(pos_item_index_chunks, dim=0)

    # Replicate each user ID for the number of item samples they have
    users = torch.arange(num_users, dtype=torch.long).unsqueeze(1)
    users = users.repeat(1, items.shape[1])

    dataset = TensorDataset(users, items, dup_mask, pos_item_indices)
    samples_per_user = items.size(1)

    return dataset, samples_per_user
141
142
def generate_padding(
    data_len: int, users_per_batch: int, samples_per_user: int
) -> TensorDataset:
    """Create dummy user rows so the eval set divides evenly into batches."""
    # Extra users needed to reach the next multiple of users_per_batch
    # (zero when data_len already divides evenly).
    leftover = data_len % users_per_batch
    pad = 0 if leftover == 0 else users_per_batch - leftover

    dummy_users = torch.zeros(pad, samples_per_user, dtype=torch.long)
    dummy_items = torch.zeros(pad, samples_per_user, dtype=torch.long)
    dummy_dup_mask = torch.zeros(pad, samples_per_user, dtype=torch.bool)
    # -1 marks padding rows: it can never match a real top-k index.
    dummy_pos_item_indices = torch.full((pad,), -1, dtype=torch.long)

    return TensorDataset(
        dummy_users, dummy_items, dummy_dup_mask, dummy_pos_item_indices
    )
157
158
def load_model(
    model: torch.nn.Module,
    optimizer: MNCoreOptimizer | None = None,
    model_path: str | os.PathLike = "./ncf_model.pth",
) -> None:
    """Restore model (and optionally optimizer) state from a checkpoint file."""
    print(f"loading model from {model_path}")
    # weights_only=True restricts unpickling to tensors/basic containers.
    checkpoint = torch.load(model_path, weights_only=True)

    model.load_state_dict(checkpoint["model"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])
170
171
def save_model(
    model: torch.nn.Module,
    optimizer: MNCoreOptimizer,
    outdir: str,
    model_path: str | os.PathLike = "ncf_model.pth",
) -> None:
    """Save model and optimizer state dicts as one checkpoint under ``outdir``.

    Args:
        model: the trained model.
        optimizer: optimizer whose state is checkpointed alongside the model.
        outdir: output directory (must already exist).
        model_path: checkpoint file name (or relative path) inside outdir.
    """
    save_path = os.path.join(outdir, model_path)
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        save_path,
    )

    # Report the full path actually written; the previous message printed
    # only model_path, which was misleading whenever outdir != ".".
    print(f"model saved to {save_path}")
1import argparse
2import os
3import pathlib
4import random
5import sys
6from collections.abc import Callable
7from typing import Any
8
9import numpy as np
10import tomllib
11import torch
12from fx2onnx import set_tensor_name
13from mlsdk import (
14 CacheOptions,
15 CompiledFunction,
16 Context,
17 MNCoreOptimizer,
18 get_tensor_name,
19 set_buffer_name_in_optimizer,
20 set_tensor_name_in_module,
21 storage,
22)
23
24
def set_deterministic_mode(seed: int) -> None:
    """Seed every RNG the pipeline touches and force deterministic torch ops."""
    # Seed Python, NumPy, and torch (CPU + CUDA) in the same order as before.
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    ):
        seeder(seed)

    # Disable cudnn autotuning and make torch raise on ops that are known
    # to be nondeterministic.
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
35
36
def register_model(
    context: Context,
    name: str,
    model: torch.nn.Module,
) -> None:
    """Name the model's tensors and register them with the MLSDK context.

    Registration is skipped when the first parameter already carries a
    tensor name, i.e. the model was registered to a context before.
    """
    first_param = next(model.parameters())
    if get_tensor_name(first_param) is not None:
        return  # already registered

    set_tensor_name_in_module(model, name)
    for param in model.parameters():
        context.register_param(param)
    for buf in model.buffers():
        context.register_buffer(buf)
50
51
def compile_fn(  # noqa: CFQ002
    context: Context,
    target_fn: Callable[
        [
            dict[str, torch.Tensor],
        ],
        dict[str, torch.Tensor],
    ],  # compiled fn
    model: torch.nn.Module | dict[str, torch.nn.Module],
    sample_input: dict[str, torch.Tensor],
    outdir: str = "/tmp/example_output",
    model_name: str = "example",
    is_train: bool = True,
    optimizers: (
        list[MNCoreOptimizer] | None
    ) = None,  # list[] is for multiple optimizers
    option_json: str = "/opt/pfn/pfcomp/codegen/preset_options/O1.json",
    preset_options_dir: str | None = None,
    enable_cache: bool = False,
    **kwargs: Any,  # used in `compile_args` in Context.compile()
) -> CompiledFunction:
    """Register model/optimizer state with ``context`` and compile ``target_fn``.

    Args:
        context: MLSDK context the compiled function runs under.
        target_fn: function mapping ``{name: tensor}`` to ``{name: tensor}``.
        model: a single module, or a dict mapping names to modules.
        sample_input: example inputs used to trace/compile target_fn.
        outdir: base directory for generated code and caches.
        model_name: label for tensor names and the codegen subdirectory.
        is_train: when True, gradient (or optimizer) buffers are registered.
        optimizers: optimizers with on-device buffers; when None and is_train,
            gradient tensors are registered instead (host-side step()).
        option_json: compiler option preset passed to Context.compile().
        preset_options_dir: currently unused; kept for interface compatibility.
        enable_cache: enable MLSDK compile caches under outdir.
        **kwargs: forwarded verbatim into Context.compile().

    Returns:
        The compiled function handle from Context.compile().
    """
    # NOTE(review): preset_options_dir is computed but never referenced
    # below; the default is kept to preserve the public interface.
    if preset_options_dir is None:
        preset_options_dir = pathlib.Path.cwd().parent.parent.parent / "preset_options"

    compile_options = {"option_json": option_json}

    compile_args = {
        "function": target_fn,
        "inputs": sample_input,
        "options": compile_options,
    }

    codegen_base_dir = storage.path(outdir)
    compile_args["codegen_dir"] = codegen_base_dir / model_name

    if enable_cache:
        compile_args["cache_options"] = CacheOptions(
            f"{outdir}/{model_name}/cache",
            enable_app_cache=True,
            enable_onnx_cache=True,
            enable_codegen_cache=True,
            enable_gpfn2obj_cache=True,
        )

    if isinstance(model, torch.nn.Module):
        register_model(context, model_name, model)
    else:  # if isinstance(models, dict[str, torch.nn.Module]):
        for name, actual_model in model.items():
            register_model(context, name, actual_model)

    if is_train:
        if optimizers is None:  # in case that optimizer.step() will be done at the host
            # Pre-allocate and register gradient tensors so the device can
            # write gradients that the host-side optimizer later consumes.
            if isinstance(model, torch.nn.Module):
                for n, p in model.named_parameters():
                    p.grad = torch.nn.Parameter(
                        torch.zeros_like(p), requires_grad=p.requires_grad
                    )
                    set_tensor_name(p.grad, f"{model_name}@{n}@grad".replace(".", "_"))
                    context.register_param(p.grad)
            else:
                for name, actual_model in model.items():
                    for n, p in actual_model.named_parameters():
                        p.grad = torch.nn.Parameter(
                            torch.zeros_like(p), requires_grad=p.requires_grad
                        )
                        # BUGFIX: append "@grad" for consistency with the
                        # single-model branch above; the bare "{name}@{n}"
                        # presumably clashed with the parameter's own name
                        # assigned by set_tensor_name_in_module.
                        set_tensor_name(
                            p.grad, f"{name}@{n}@grad".replace(".", "_")
                        )
                        context.register_param(p.grad)
        else:
            for idx, optimizer in enumerate(optimizers):
                optimizer_name = "optimizer" + str(idx)
                set_buffer_name_in_optimizer(optimizer, optimizer_name)
                context.register_optimizer_buffers(optimizer)

    compile_args.update(kwargs)

    return context.compile(**compile_args)
129
130
131# for type hint of the configs from toml
132class TomlValue:
133 str | int | float | bool | list["TomlValue"] | dict[str, "TomlValue"]
134
135
136class TomlDict:
137 dict[str, TomlValue]
138
139
140def read_configs_from_toml(
141 toml_path: str,
142) -> TomlDict:
143
144 configs_dict = None
145 with open(toml_path, mode="rb") as f:
146 configs_dict = tomllib.load(f)
147
148 return configs_dict
149
150
151def str2bool(v: bool | str) -> bool:
152 if v.lower() in ("yes", "true", "on", "enable", "y", "t", "1"):
153 return True
154 elif v.lower() in ("no", "false", "off", "disable", "n", "f", "0"):
155 return False
156 elif isinstance(v, str | bool):
157 return v
158 else:
159 raise argparse.ArgumentTypeError("Str or boolean value expected")
160
161
162def apply_toml_defaults(
163 configs: TomlDict | str | os.PathLike,
164 parser: argparse.ArgumentParser,
165) -> None:
166
167 if isinstance(configs, dict):
168 for k, v in configs.items():
169 if isinstance(v, dict): # in case v is (nested) dict
170 apply_toml_defaults(v, parser)
171 else:
172 # just checking whether v is list is enough for array args
173 # because array in toml is converted to the list by tomllib
174 args_type = None
175 if isinstance(v, list):
176 args_type = type(v[0])
177 elif isinstance(v, bool):
178 args_type = str2bool
179 else:
180 args_type = type(v)
181 parser.add_argument(
182 f"--{k}",
183 default=v,
184 type=args_type,
185 nargs="*" if isinstance(v, list) else "?",
186 )
187 elif isinstance(configs, str | os.PathLike):
188 configs_dict = read_configs_from_toml(configs)
189
190 apply_toml_defaults(configs_dict, parser)
191 else:
192 sys.exit("")
1absl-py==0.7.0
2numpy==1.16.2
3pandas==0.24.2
4protobuf==3.19.6
5scikit-image==0.14.2
6scikit-learn==0.20.3
7scipy==1.2.1
8six==1.12.0
9tensorflow==1.13.1
1scipy==1.16.0
2torch==2.9.0
3numpy==2.3.0
4numpy_indexed==0.3.7
5pandas==2.3.0
6mlperf_compliance==0.0.10
1title = "ncf_training"
2
3
4[model]
5factors = 64
6mlp_layers = [256, 256, 128, 64]
7save_path = "./ncf_model.pth"
8load_path = ""
9model_path = "./ncf_model.pth"
10
11
12[dataset]
13dataset = "ml-1m"
14user_scaling = 1 # this value must be the same as specified in preparation.sh
15item_scaling = 1 # this value must be the same as specified in preparation.sh
16train_neg_ratio = 4
17eval_neg_ratio = 999
18allow_collision_with_pos = false
19loader_num_workers = 2
20
21
22[training]
23epoch = 20
24learning_rate = 0.0002
25
26
27[evaluation]
28topk = 10
29
30
31[misc]
32seed = 0
33train_batch_size = 65536 # 2**16
34eval_batch_size = 16384 # 2**14