We introduce PDEgym, a collection of datasets covering multiple PDEs and operators.
Use the dropdown menu below to see visualizations of samples of time-dependent datasets with normalized time.
Use the dropdown menu below to see visualizations of samples of time-independent datasets. The top row corresponds to inputs, the bottom to outputs. For Helmholtz, the number below each sample is the value of the Dirichlet boundary condition.
We encourage using our pretrained models on your own datasets. To that end, you can directly plug your dataset into our code and then finetune using our scripts. With your data at hand, you may follow these steps:
Create a dataset class in the scOT/problems subfolder of the code that inherits from BaseDataset or BaseTimeDataset in scOT/problems/base.py, depending on whether your problem is time-independent or time-dependent, respectively. Make sure to also pass *args, **kwargs as arguments. A minimal example of the fields that have to be set additionally can be seen in the snippet below (for a time-dependent problem).
```python
class MyDataset(BaseTimeDataset):
    def __init__(self, *args, **kwargs):
        # The following arguments are inherited:
        # which: str, can be "train", "val", or "test", depending on the split to get
        # num_trajectories: int, number of trajectories/samples in the training set to get
        # data_path: str, path to the data directory
        # move_to_local_scratch: str, set to a directory that offers higher bandwidth,
        #   if applicable; data will be copied there before training
        # max_num_time_steps: int, maximum number of time steps to consider (only relevant for time-dependent)
        # time_step_size: int, time step size (only relevant for time-dependent)
        # fix_input_to_time_step: int, if set, the input will be fixed to this time step (only relevant for time-dependent)
        # allowed_time_transitions: list of int, allowed time transitions,
        #   if you don't want to use full all2all (only relevant for time-dependent)
        super().__init__(*args, **kwargs)

        self.N_max = None  # maximum number of trajectories/samples in the full dataset
        self.N_val = None  # number of validation samples
        self.N_test = None  # number of test samples
        self.input_dim = None  # input dimensionality for the model
        self.label_description = None  # a str built as "[channel1],[channel2,channel3]",
        #   where channel1, channel2, channel3 are the names of the output channels;
        #   channel2 and channel3 will be interpreted as a vectorized function
        #   in the loss and evaluation

        # You may want to call self._move_to_local_scratch(path_to_dataset)
        # if you want to allow copying before opening the dataset.
        self.post_init()

    # __len__() is inherited and does not need to be implemented.

    def __getitem__(self, idx):
        # For time-dependent problems: i is the trajectory, t the lead time,
        # t1 the initial time, t2 the final time.
        i, t, t1, t2 = self._idx_map(idx)
        # Get your data using i + self.start for the trajectory
        # (idx + self.start for the sample in the time-independent case).
        # Return a dict with input, output, and lead time (the latter only for
        # time-dependent problems):
        # inputs should have shape (self.input_dim, resolution, resolution),
        # outputs should have shape (output_dim, resolution, resolution).
        return {"pixel_values": inputs, "labels": outputs, "time": time}
```
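Independent of the scOT base classes, the sample contract that `__getitem__` must satisfy can be checked in isolation. The sketch below builds one such dict with NumPy and verifies the expected shapes; the channel counts and resolution are arbitrary placeholders, not taken from any real dataset:

```python
import numpy as np

# Placeholder sizes for illustration only:
input_dim, output_dim, resolution = 4, 3, 32

# A dummy sample in the format __getitem__ is expected to return
# for a time-dependent problem.
inputs = np.zeros((input_dim, resolution, resolution), dtype=np.float32)
outputs = np.zeros((output_dim, resolution, resolution), dtype=np.float32)
sample = {"pixel_values": inputs, "labels": outputs, "time": 0.5}

assert sample["pixel_values"].shape == (input_dim, resolution, resolution)
assert sample["labels"].shape == (output_dim, resolution, resolution)
```

For a time-independent problem, the `"time"` entry is simply omitted.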
After implementing the dataset, add it to the getter method get_dataset in scOT/problems/base.py so that a string identifier can select the dataset; this identifier is used by the training script. You may also add default settings there.
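The exact dispatch inside get_dataset may look different in the repository; purely as an illustration, mapping a string identifier to a dataset class could be wired up as follows (MyDataset here is a stand-in stub, and the identifier string is a placeholder):

```python
# Hypothetical sketch of the dispatch in get_dataset; the actual function in
# scOT/problems/base.py may be structured differently.
class MyDataset:
    """Stand-in stub for the BaseTimeDataset subclass implemented above."""
    def __init__(self, *args, **kwargs):
        self.kwargs = kwargs

def get_dataset(dataset: str, **kwargs):
    # The string identifier is what the training script uses to pick the dataset.
    if dataset == "mycollection.MyDataset":  # placeholder identifier
        return MyDataset(**kwargs)
    raise ValueError(f"Unknown dataset: {dataset}")

# Example: the training script would effectively do
ds = get_dataset("mycollection.MyDataset", data_path="data/")
```

Default settings (e.g. time-step parameters) can then be injected as keyword arguments at this point.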
Create or adapt a configuration file for the training script in the configs subdirectory.
```shell
accelerate launch scOT/train.py \
    --config YOUR_WANDB_CONFIG_FILE \
    --wandb_run_name WANDB_RUN_NAME \
    --wandb_project_name WANDB_PROJECT_NAME \
    --checkpoint_path CHECKPOINT_PATH \
    --data_path DATA_PATH \
    --finetune_from PRETRAINED_MODEL \
    --replace_embedding_recovery  # pass this flag only if the embeddings/recovery need to be replaced
```
The script will automatically log to Weights & Biases and save the best model to the checkpoint path.
If you find our work useful, please cite our paper:
```bibtex
@misc{herde2024poseidon,
    title={Poseidon: Efficient Foundation Models for PDEs},
    author={Maximilian Herde and Bogdan Raonić and Tobias Rohner and Roger Käppeli and Roberto Molinaro and Emmanuel de Bézenac and Siddhartha Mishra},
    year={2024},
    eprint={2405.19101},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
```