001// SPDX-FileCopyrightText: 2021 Paul Schaub <vanitasvitae@fsfe.org>
002//
003// SPDX-License-Identifier: Apache-2.0
004
005package org.pgpainless.algorithm.negotiation;
006
007import java.util.ArrayList;
008import java.util.Collections;
009import java.util.HashMap;
010import java.util.LinkedHashMap;
011import java.util.List;
012import java.util.Map;
013import java.util.Set;
014
015import org.pgpainless.algorithm.SymmetricKeyAlgorithm;
016import org.pgpainless.policy.Policy;
017
018/**
019 * Interface for symmetric key algorithm negotiation.
020 */
021public interface SymmetricKeyAlgorithmNegotiator {
022
023    /**
024     * Negotiate a symmetric encryption algorithm.
025     *
026     * @param policy algorithm policy
027     * @param override algorithm override (if not null, return this)
028     * @param keyPreferences list of preferences per key
029     * @return negotiated algorithm
030     */
031    SymmetricKeyAlgorithm negotiate(Policy.SymmetricKeyAlgorithmPolicy policy, SymmetricKeyAlgorithm override, List<Set<SymmetricKeyAlgorithm>> keyPreferences);
032
033    static SymmetricKeyAlgorithmNegotiator byPopularity() {
034        return new SymmetricKeyAlgorithmNegotiator() {
035            @Override
036            public SymmetricKeyAlgorithm negotiate(Policy.SymmetricKeyAlgorithmPolicy policy, SymmetricKeyAlgorithm override, List<Set<SymmetricKeyAlgorithm>> preferences) {
037                if (override == SymmetricKeyAlgorithm.NULL) {
038                    throw new IllegalArgumentException("Algorithm override cannot be NULL (plaintext).");
039                }
040
041                if (override != null) {
042                    return override;
043                }
044
045                // Count score (occurrences) of each algorithm
046                Map<SymmetricKeyAlgorithm, Integer> supportWeight = new LinkedHashMap<>();
047                for (Set<SymmetricKeyAlgorithm> keyPreferences : preferences) {
048                    for (SymmetricKeyAlgorithm preferred : keyPreferences) {
049                        if (supportWeight.containsKey(preferred)) {
050                            supportWeight.put(preferred, supportWeight.get(preferred) + 1);
051                        } else {
052                            supportWeight.put(preferred, 1);
053                        }
054                    }
055                }
056
057                // Pivot the score map
058                Map<Integer, List<SymmetricKeyAlgorithm>> byScore = new HashMap<>();
059                for (SymmetricKeyAlgorithm algorithm : supportWeight.keySet()) {
060                    int score = supportWeight.get(algorithm);
061                    List<SymmetricKeyAlgorithm> withSameScore = byScore.get(score);
062                    if (withSameScore == null) {
063                        withSameScore = new ArrayList<>();
064                        byScore.put(score, withSameScore);
065                    }
066                    withSameScore.add(algorithm);
067                }
068
069                List<Integer> scores = new ArrayList<>(byScore.keySet());
070
071                // Sort map and iterate from highest to lowest score
072                Collections.sort(scores);
073                for (int i = scores.size() - 1; i >= 0; i--) {
074                    int score = scores.get(i);
075                    List<SymmetricKeyAlgorithm> withSameScore = byScore.get(score);
076                    // Select best algorithm
077                    SymmetricKeyAlgorithm best = policy.selectBest(withSameScore);
078                    if (best != null) {
079                        return best;
080                    }
081                }
082
083                // If no algorithm is acceptable, choose fallback
084                return policy.getDefaultSymmetricKeyAlgorithm();
085            }
086        };
087    }
088}