Last active 9 months ago

Save and load hyperparameters for model training and evaluation

hparams_save-load.py (raw)
1from datetime import datetime
2from pathlib import Path
3
4import yaml
5
6
7# Define the model
8class MyModel:
9 def __init__(
10 self,
11 n_layers: int,
12 n_heads: int,
13 dropout: float,
14 outdir: str | Path,
15 ):
16 self.n_layers = n_layers
17 self.n_heads = n_heads
18 self.dropout = dropout
19 self.outdir = outdir
20
21 @classmethod
22 def from_config(cls, config: Path | str):
23 with open(config, "r") as f:
24 hparams = yaml.safe_load(f)
25 return cls(**hparams)
26
27 def load_checkpoint(self, checkpoint: Path | str):
28 print(f"Loading checkpoint from {checkpoint}...")
29
30 def train(self):
31 print("Training...")
32 print("Saving checkpoint...")
33 with open(Path(self.outdir) / "checkpoint.pt", "w") as f:
34 f.write("Hello, world!")
35
36 def evaluate(self):
37 print("Evaluating...")
38
39
40# Create the output directory
# Name the run directory after the current local timestamp.
now = datetime.now().strftime("%Y%m%d_%H%M%S")
outdir = Path(f"{now}_runs")
outdir.mkdir(parents=True, exist_ok=True)

# Instantiate the model
hparams = {
    "n_layers": 3,
    "n_heads": 8,
    "dropout": 0.1,
    # Stored as a plain string so the value round-trips through YAML.
    "outdir": str(outdir),
}

model = MyModel(**hparams)

# Save the hyperparameters in the run directory *before* training, so the
# configuration is still on disk if training crashes or is interrupted.
# safe_dump mirrors the safe_load used when the file is read back: a value
# the loader could not parse is rejected here instead of at evaluation time.
with open(outdir / "hparams.yaml", "w", encoding="utf-8") as f:
    yaml.safe_dump(hparams, f)

# Train the model
model.train()

# Evaluate the model: rebuild it from the saved hyperparameters, then point
# it at the checkpoint written during training.
new_model = MyModel.from_config(outdir / "hparams.yaml")
new_model.load_checkpoint(outdir / "checkpoint.pt")
new_model.evaluate()
66