Skip to content

凸优化

几何中心

给定一些二维平面上的点,求一个目标点(即这些点的几何中位数,geometric median),使得其到所有给定点的欧氏距离之和最小。

\[ \displaylines{ \min_{x,y}\sum_{i}\sqrt{(x-x_i)^2 + (y-y_i)^2} } \]

一般情况下该问题没有解析解,但目标函数为凸函数,故可使用梯度下降法求解(下面的实现采用带学习率衰减的小批量随机梯度下降):

/// Numerically approximates the geometric median of a set of 2D points and
/// reports the minimal total Euclidean distance from that point to all inputs.
class Solution {
public:
    /// @brief Minimises sum_i sqrt((x - x_i)^2 + (y - y_i)^2) over (x, y)
    ///        with mini-batch stochastic gradient descent.
    ///
    /// The search starts at the arithmetic mean of the points and uses a
    /// geometrically decaying learning rate. The batch order is reshuffled
    /// every epoch with a generator seeded from std::random_device, so the
    /// returned value is an approximation that may differ in the low-order
    /// digits from run to run.
    ///
    /// @param positions Points as {x, y} integer pairs (each inner vector must
    ///                  hold at least two elements). The vector is not modified.
    /// @return The objective value at the converged point; 0.0 for empty input.
    double getMinDistSum(std::vector<std::vector<int>>& positions) {
        // Without this guard the mean below is 0/0 (NaN), and a NaN never
        // satisfies the convergence test, so the epoch loop would spin forever.
        if (positions.empty()) {
            return 0.0;
        }

        constexpr double kEps = 1e-7;          // convergence threshold and division guard
        constexpr double kDecay = 1e-3;        // per-batch learning-rate decay
        constexpr std::size_t kMaxBatch = 32;  // upper bound on the mini-batch size

        struct Point {
            double x;
            double y;
        };

        // Work on a private copy: shuffling must not reorder the caller's data,
        // and converting to double once avoids repeated int -> double casts.
        std::vector<Point> pts;
        pts.reserve(positions.size());
        double x = 0.0;
        double y = 0.0;
        for (const auto& pos : positions) {
            pts.push_back({static_cast<double>(pos[0]), static_cast<double>(pos[1])});
            x += pos[0];
            y += pos[1];
        }
        const std::size_t n = pts.size();
        const std::size_t batchSize = std::min(n, kMaxBatch);

        // The centroid is the starting guess.
        x /= static_cast<double>(n);
        y /= static_cast<double>(n);

        std::mt19937 gen{std::random_device{}()};
        double alpha = 1.0;  // learning rate

        // One iteration == one epoch. Termination is guaranteed: each batch
        // moves the estimate by at most alpha * batchSize, and alpha shrinks
        // geometrically, so the per-epoch displacement eventually drops
        // below kEps.
        while (true) {
            std::shuffle(pts.begin(), pts.end(), gen);
            const double xPrev = x;
            const double yPrev = y;

            for (std::size_t i = 0; i < n; i += batchSize) {
                const std::size_t j = std::min(i + batchSize, n);
                double dx = 0.0;
                double dy = 0.0;
                // Gradient of the batch: sum of unit vectors pointing from
                // each sample to the current estimate.
                for (std::size_t k = i; k < j; ++k) {
                    const double ox = x - pts[k].x;
                    const double oy = y - pts[k].y;
                    // kEps keeps the division finite when the estimate sits
                    // exactly on a sample.
                    const double denom = std::hypot(ox, oy) + kEps;
                    dx += ox / denom;
                    dy += oy / denom;
                }
                x -= alpha * dx;
                y -= alpha * dy;
                alpha *= (1.0 - kDecay);
            }

            // Converged once a whole epoch barely moved the estimate.
            if (std::hypot(x - xPrev, y - yPrev) < kEps) {
                break;
            }
        }

        // Evaluate the objective at the converged point.
        double total = 0.0;
        for (const Point& p : pts) {
            total += std::hypot(p.x - x, p.y - y);
        }
        return total;
    }
};