riccardo revised this gist 9 months ago. Go to revision
1 file changed, 9 insertions, 2 deletions
hparams_save-load.py
| @@ -24,8 +24,14 @@ class MyModel: | |||
| 24 | 24 | hparams = yaml.safe_load(f) | |
| 25 | 25 | return cls(**hparams) | |
| 26 | 26 | ||
| 27 | + | def load_checkpoint(self, checkpoint: Path | str): | |
| 28 | + | print(f"Loading checkpoint from {checkpoint}...") | |
| 29 | + | ||
| 27 | 30 | def train(self): | |
| 28 | 31 | print("Training...") | |
| 32 | + | print("Saving checkpoint...") | |
| 33 | + | with open(Path(self.outdir) / "checkpoint.pt", "w") as f: | |
| 34 | + | f.write("Hello, world!") | |
| 29 | 35 | ||
| 30 | 36 | def evaluate(self): | |
| 31 | 37 | print("Evaluating...") | |
| @@ -54,5 +60,6 @@ with open(outdir / "hparams.yaml", "w") as f: | |||
| 54 | 60 | yaml.dump(hparams, f) | |
| 55 | 61 | ||
| 56 | 62 | # Evaluate the model | |
| 57 | - | model = MyModel.from_config(outdir / "hparams.yaml") | |
| 58 | - | model.evaluate() | |
| 63 | + | new_model = MyModel.from_config(outdir / "hparams.yaml") | |
| 64 | + | new_model.load_checkpoint(outdir / "checkpoint.pt") | |
| 65 | + | new_model.evaluate() | |
riccardo revised this gist 9 months ago. Go to revision
1 file changed, 58 insertions
hparams_save-load.py (file created)
| @@ -0,0 +1,58 @@ | |||
| 1 | + | from datetime import datetime | |
| 2 | + | from pathlib import Path | |
| 3 | + | ||
| 4 | + | import yaml | |
| 5 | + | ||
| 6 | + | ||
| 7 | + | # Define the model | |
| 8 | + | class MyModel: | |
| 9 | + | def __init__( | |
| 10 | + | self, | |
| 11 | + | n_layers: int, | |
| 12 | + | n_heads: int, | |
| 13 | + | dropout: float, | |
| 14 | + | outdir: str | Path, | |
| 15 | + | ): | |
| 16 | + | self.n_layers = n_layers | |
| 17 | + | self.n_heads = n_heads | |
| 18 | + | self.dropout = dropout | |
| 19 | + | self.outdir = outdir | |
| 20 | + | ||
| 21 | + | @classmethod | |
| 22 | + | def from_config(cls, config: Path | str): | |
| 23 | + | with open(config, "r") as f: | |
| 24 | + | hparams = yaml.safe_load(f) | |
| 25 | + | return cls(**hparams) | |
| 26 | + | ||
| 27 | + | def train(self): | |
| 28 | + | print("Training...") | |
| 29 | + | ||
| 30 | + | def evaluate(self): | |
| 31 | + | print("Evaluating...") | |
| 32 | + | ||
| 33 | + | ||
| 34 | + | # Create the output directory | |
| 35 | + | now = datetime.now().strftime("%Y%m%d_%H%M%S") | |
| 36 | + | outdir = Path(f"{now}_runs") | |
| 37 | + | outdir.mkdir(parents=True, exist_ok=True) | |
| 38 | + | ||
| 39 | + | # Instantiate the model | |
| 40 | + | hparams = { | |
| 41 | + | "n_layers": 3, | |
| 42 | + | "n_heads": 8, | |
| 43 | + | "dropout": 0.1, | |
| 44 | + | "outdir": str(outdir), | |
| 45 | + | } | |
| 46 | + | ||
| 47 | + | model = MyModel(**hparams) | |
| 48 | + | ||
| 49 | + | # Train the model | |
| 50 | + | model.train() | |
| 51 | + | ||
| 52 | + | # Save the hyperparameters in the same directory | |
| 53 | + | with open(outdir / "hparams.yaml", "w") as f: | |
| 54 | + | yaml.dump(hparams, f) | |
| 55 | + | ||
| 56 | + | # Evaluate the model | |
| 57 | + | model = MyModel.from_config(outdir / "hparams.yaml") | |
| 58 | + | model.evaluate() | |