Equivariant Architectures for Learning in Deep Weight Spaces

Aviv Navon, Aviv Shamsian, Idan Achituve, Ethan Fetaya, Gal Chechik, Haggai Maron

Research output: Contribution to journal › Conference article › peer-review

Abstract

Designing machine learning architectures for processing neural networks in their raw weight matrix form is a newly introduced research direction. Unfortunately, the unique symmetry structure of deep weight spaces makes this design very challenging. If successful, such architectures would be capable of performing a wide range of intriguing tasks, from adapting a pre-trained network to a new domain to editing objects represented as functions (INRs or NeRFs). As a first step towards this goal, we present here a novel network architecture for learning in deep weight spaces. It takes as input a concatenation of weights and biases of a pre-trained MLP and processes it using a composition of layers that are equivariant to the natural permutation symmetry of the MLP's weights: Changing the order of neurons in intermediate layers of the MLP does not affect the function it represents. We provide a full characterization of all affine equivariant and invariant layers for these symmetries and show how these layers can be implemented using three basic operations: pooling, broadcasting, and fully connected layers applied to the input in an appropriate manner. We demonstrate the effectiveness of our architecture and its advantages over natural baselines in a variety of learning tasks.
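The permutation symmetry described in the abstract can be checked numerically. The following is a minimal sketch and not the authors' architecture: it assumes a small ReLU MLP with hypothetical layer sizes 3-5-4-2 and random weights, and the helper names `mlp_forward` and `permute_hidden_layer` are illustrative only. It reorders the neurons of one hidden layer (rows of that layer's weight matrix and bias, and the matching columns of the next layer's weight matrix) and confirms that the network output is unchanged up to floating-point rounding.

```python
import random


def mlp_forward(x, weights, biases):
    """Evaluate a ReLU MLP. weights[i] is a list of rows with shape (d_out, d_in)."""
    h = list(x)
    last = len(weights) - 1
    for i, (W, b) in enumerate(zip(weights, biases)):
        # Affine map: one dot product per output neuron, plus its bias.
        h = [sum(w * v for w, v in zip(row, h)) + bias for row, bias in zip(W, b)]
        if i < last:
            # ReLU on hidden layers only; the output layer stays linear.
            h = [max(v, 0.0) for v in h]
    return h


def permute_hidden_layer(weights, biases, layer, perm):
    """Return copies of the parameters with the neurons of one hidden layer reordered."""
    weights = [[row[:] for row in W] for W in weights]
    biases = [b[:] for b in biases]
    # Reorder the outputs of this layer: rows of W and entries of b.
    weights[layer] = [weights[layer][p] for p in perm]
    biases[layer] = [biases[layer][p] for p in perm]
    # Reorder the inputs of the next layer to match: columns of the next W.
    weights[layer + 1] = [[row[p] for p in perm] for row in weights[layer + 1]]
    return weights, biases


# Hypothetical sizes and random parameters, chosen only for illustration.
rng = random.Random(0)
dims = [3, 5, 4, 2]
weights = [
    [[rng.gauss(0.0, 1.0) for _ in range(dims[i])] for _ in range(dims[i + 1])]
    for i in range(len(dims) - 1)
]
biases = [[rng.gauss(0.0, 1.0) for _ in range(dims[i + 1])] for i in range(len(dims) - 1)]
x = [rng.gauss(0.0, 1.0) for _ in range(dims[0])]

perm = [2, 0, 4, 1, 3]  # a fixed reordering of the 5 neurons in the first hidden layer
permuted_weights, permuted_biases = permute_hidden_layer(weights, biases, 0, perm)

y_original = mlp_forward(x, weights, biases)
y_permuted = mlp_forward(x, permuted_weights, permuted_biases)
max_abs_diff = max(abs(a - b) for a, b in zip(y_original, y_permuted))
print("outputs agree:", max_abs_diff < 1e-9)
```

The two parameter sets are different points in weight space that represent the same function, which is why a model that takes weights as input benefits from being equivariant or invariant to such reorderings.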

Original language: English
Pages (from-to): 25790-25816
Number of pages: 27
Journal: Proceedings of Machine Learning Research
Volume: 202
State: Published - 2023
Event: 40th International Conference on Machine Learning, ICML 2023 - Honolulu, United States
Duration: 23 Jul 2023 → 29 Jul 2023

Bibliographical note

Publisher Copyright:
© 2023 Proceedings of Machine Learning Research. All rights reserved.
