🧜‍♂️ Poseidon: Efficient Foundation Models for PDEs

Maximilian Herde1,*, Bogdan Raonić1,2,*, Tobias Rohner1, Roger Käppeli1, Roberto Molinaro1, Emmanuel de Bézenac1, Siddhartha Mishra1,2
1Seminar for Applied Mathematics, ETH Zurich, Switzerland, 2ETH AI Center, Zurich, Switzerland, *Equal contribution

PDEgym

We introduce PDEgym, a collection of datasets covering multiple PDEs and their associated operators.

Use the dropdown menu below to see visualizations of samples from the time-dependent datasets, shown with normalized time.

Use the dropdown menu below to see visualizations of samples from the time-independent datasets. The top row shows the inputs, the bottom row the outputs. For Helmholtz, the number below each sample is the value of the Dirichlet boundary condition.


Usage

We encourage using our pretrained models on your own datasets: you can plug your dataset directly into our code and finetune with our scripts. With your data at hand, follow these steps:

  1. Put your dataset in an appropriate directory. Then, write a dataset class in a file in the scOT/problems subfolder of the code that inherits from BaseDataset (for time-independent problems) or BaseTimeDataset (for time-dependent problems), both defined in scOT/problems/base.py. Make sure to also pass *args, **kwargs to the parent constructor. The snippet below shows a minimal example of the fields that additionally have to be set (for a time-dependent problem); a smoke test and a registration sketch follow after the snippet.
    
    from scOT.problems.base import BaseTimeDataset

    class MyDataset(BaseTimeDataset):
        def __init__(self, *args, **kwargs):
            # the following arguments are inherited:
            # which: str, can be train, val, or test, depending on the split to get
            # num_trajectories: int, number of trajectories/samples in the training set to get
            # data_path: str, path to the data directory
            # move_to_local_scratch: str, if applicable, set to a directory that offers higher bandwidth; data will be copied there before training
            # max_num_time_steps: int, maximum number of time steps to consider (only relevant for time-dependent)
            # time_step_size: int, time step size (only relevant for time-dependent)
            # fix_input_to_time_step: int, if set, the input will be fixed to this time step (only relevant for time-dependent)
            # allowed_time_transitions: list of int, allowed time transitions, if you don't want to use full all2all (only relevant for time-dependent)
            super().__init__(*args, **kwargs)
    
            self.N_max = None # maximum number of trajectories/samples in full dataset
            self.N_val = None # number of validation samples
            self.N_test = None # number of test samples
    
            self.input_dim = None # input dimensionality for the model
            self.label_description = None # a str built as "[channel1],[channel2,channel3]",
            # naming the output channels; channels grouped in one bracket (here, channel2
            # and channel3) are interpreted as one vector-valued function in the loss
            # and in evaluation
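            # illustrative example with hypothetical channel names: "[rho],[u,v]"
            # for a scalar density and a two-component velocity field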
    
            # you may want to call self._move_to_local_scratch(path_to_dataset)
            # here to support copying the data before opening the dataset
            # (see move_to_local_scratch above)
    
            self.post_init()
        
        # __len__() is inherited and does not need to be implemented
    
        def __getitem__(self, idx):
            # for time-dependent problems: i is the trajectory index, t the lead time,
            # t1 the initial time, t2 the final time
            i, t, t1, t2 = self._idx_map(idx)
    
            # get your data using i + self.start for the trajectory, idx + self.start for the sample in the time-independent case
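            # hypothetical example, assuming an array self.data of shape (N, T, C, H, W)
            # was loaded in __init__ (all names and the time scaling are placeholders):
            #   inputs = torch.from_numpy(self.data[i + self.start, t1]).float()
            #   outputs = torch.from_numpy(self.data[i + self.start, t2]).float()
            #   time = t / self.max_num_time_steps  # if a normalized lead time is expected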
            
            # return a dict with input, output, lead time (in case of time-dependent problems)
            # inputs should have shape (self.input_dim, resolution, resolution)
            # outputs should have shape (output_dim, resolution, resolution)
            return {"pixel_values": inputs, "labels": outputs, "time": time}
                    
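    Once the class is implemented, a quick smoke test along the following lines can catch shape and key mistakes early. This is only a sketch: it assumes the inherited constructor accepts the arguments listed in the snippet as keyword arguments, and all concrete values (path, counts, step sizes) are placeholders.

    dataset = MyDataset(
        which="train",
        num_trajectories=128,
        data_path="/path/to/data",
        max_num_time_steps=7,
        time_step_size=2,
    )
    sample = dataset[0]
    # inputs should have shape (input_dim, resolution, resolution)
    assert sample["pixel_values"].shape[0] == dataset.input_dim
    assert set(sample.keys()) == {"pixel_values", "labels", "time"}
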
    After implementing the dataset, add it to the getter method get_dataset in scOT/problems/base.py so that the dataset can be identified by a string; this string is used by the training script. You may also add default settings there. A hedged sketch of such a registration follows.
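
    The registration might look like the following outline. This is a hypothetical sketch: the actual dispatch logic and signature of get_dataset may differ, so follow the existing entries in scOT/problems/base.py. The string "my_dataset" and the module name my_problem are made-up placeholders.

    def get_dataset(dataset, **kwargs):
        ...
        if dataset == "my_dataset":
            from scOT.problems.my_problem import MyDataset
            # default settings for the dataset can be injected here
            return MyDataset(**kwargs)
        ...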
  2. Write your own configuration file, see the examples in the configs subdirectory.
  3. Run the finetuning/training script to finetune the model on your dataset.
    
    accelerate launch scOT/train.py \
        --config YOUR_WANDB_CONFIG_FILE \
        --wandb_run_name WANDB_RUN_NAME \
        --wandb_project_name WANDB_PROJECT_NAME \
        --checkpoint_path CHECKPOINT_PATH \
        --data_path DATA_PATH \
        --finetune_from PRETRAINED_MODEL \
        --replace_embedding_recovery  # flag; set only if the embeddings/recovery need to be replaced
                        
    The script will automatically log to Weights & Biases and save the best model to the checkpoint path.
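
    To run inference with the finetuned model afterwards, something along the following lines should work. This is a hypothetical sketch: it assumes the repository's model class is ScOT (in scOT/model.py) and that it follows the Hugging Face from_pretrained convention; verify both against the code. All shapes and the checkpoint path are placeholders.

    import torch

    from scOT.model import ScOT  # assumed import path of the model class

    model = ScOT.from_pretrained("CHECKPOINT_PATH")  # the path the training script saved to
    model.eval()

    inputs = torch.randn(1, 4, 128, 128)  # (batch, input_dim, resolution, resolution); placeholder dims
    time = torch.tensor([1.0])  # lead time(s); the scaling convention is an assumption
    with torch.no_grad():
        output = model(pixel_values=inputs, time=time)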

Cite

If you find our work useful, please cite our paper:


@misc{herde2024poseidon,
    title={Poseidon: Efficient Foundation Models for PDEs}, 
    author={Maximilian Herde and Bogdan Raonić and Tobias Rohner and Roger Käppeli and Roberto Molinaro and Emmanuel de Bézenac and Siddhartha Mishra},
    year={2024},
    eprint={2405.19101},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}