SRM 401 DIV1 Hard NCool

問題 Editorial

問題

凸多角形(x,y)が与えられる。
ある整数点がN-coolであるとは、多角形の中(辺上も含む)にあって、少なくとも1つのN-coolな線分の端点となっていることである。
ある線分がN-coolであるとは、多角形の中にある整数点を少なくともN個含むことである。
nが与えられる。n-coolな整数点を数えよ。

  • 3 <= |x| = |y| <= 50
  • 0 <= x[i], y[i] <= 10000
  • 与えられる多角形の中に含まれる点の数は500000を超えない
  • 2 <= n <= 500000

解答

Editorialを参考にした。
まず、点の数の制約から全ての点を列挙することができることはわかる。
実際に列挙するには平面走査をする。凸多角形なので、上部と下部に分けることが出来る。この上部と下部は凸包を求めるアルゴリズムで得られる。
範囲内のそれぞれのx座標に対し、上部と下部の対応するy座標の区間の範囲を列挙すればよい。このとき、floorとceilをする整数割り算関数を負の数にも対応するように適当に実装する必要がある。
しかしある点がn-coolであるかをどのように判定するのか。
まず、N-coolな線分を「N個以上含む」の代わりに「N個ちょうど含む」としても変わらないことがわかる。また、明らかに端点は整数点としてよい。
ここで重要な定理が存在する:
任意の整数(N≧2),(0≦x,y)に対し、(0,0)と(x,y)を結ぶ線分がちょうどNつの整数点を含む <-> x ≡ 0 (mod N-1) and y ≡ 0 (mod N-1) and gcd(x/(N-1), y/(N-1)) = 1
なんとなくはわかるけれど証明できない…
とにかく、これによって簡単にできる。まず、gcd = 1の制約は無視してよい。なぜなら、gcdが1でなくても片方の端点のx,y座標を適当にk(N-1) (kは整数)動かせば、もう片方の端点の座標を保存したままgcd = 1とできるから(これは多角形が凸であるからできる。凸でないと必ずしも中間の点を取れないので)。
あとは組(x mod N-1, y mod N-1)で決定される同値類を考えて、2つ以上多角形内の点が含まれるような類に対してのその数を総和すればよい。これはmapで数えればよい。

コード

#include <vector>
#include <algorithm>
#include <map>
#define rep(i,n) for(int (i)=0;(i)<(int)(n);++(i))
#define each(it,o) for(__typeof((o).begin()) it = (o).begin(); it != (o).end(); ++ it)
#define pb(x) push_back(x)
#define mp(x,y) make_pair((x),(y))
using namespace std;
typedef pair<int,int> pii; typedef long long ll;   

struct P {
	typedef int T; typedef ll T2;	//T2は{a*b | a:T, b:T}を含むタイプ
	T x, y;
	P(T x_, T y_): x(x_), y(y_) {}
	P(): x(0), y(0) {}
};
inline bool operator==(const P& a, const P& b) { return a.x == b.x && a.y == b.y; }
inline bool operator<(const P& a, const P& b) { return a.x < b.x || (a.x == b.x && a.y < b.y); }

inline int ccw(const P& a, const P& b, const P& c) {
	int ax = b.x - a.x, ay = b.y - a.y, bx = c.x - a.x, by = c.y - a.y;
	P::T2 t = (P::T2)ax*by - (P::T2)ay*bx;
	if (t > 0) return 1;
	if (t < 0) return -1;
	if((P::T2)ax*bx + (P::T2)ay*by < 0) return +2;
	if((P::T2)ax*ax + (P::T2)ay*ay < (P::T2)bx*bx + (P::T2)by*by) return -2;
	return 0;
}

void lower_and_upper_convex_hull(vector<P> ps, vector<P> &lower, vector<P> &upper) {
	int n = ps.size(), k = 0;
	sort(ps.begin(), ps.end());
	vector<P> ch(2*n);
	for (int i = 0; i < n; ch[k++] = ps[i++]) // lower-hull
		while (k >= 2 && ccw(ch[k-2], ch[k-1], ps[i]) <= 0) --k;
	int t = k+1;
	for (int i = n-2; i >= 0; ch[k++] = ps[i--]) // upper-hull
		while (k >= t && ccw(ch[k-2], ch[k-1], ps[i]) <= 0) --k;
	lower.assign(ch.begin(), ch.begin()+(t >= 3 && ch[t-2].x == ch[t-3].x ? t-2 : t-1));
	upper.assign(ch.begin()+(t-2), ch.begin()+(k >= 2 && ch[k-2].x == ch[k-1].x ? k-1 : k));
	reverse(upper.begin(), upper.end());
}

template<typename T, typename U>
inline auto floordiv(T x, U y) -> decltype(x/y) {
	auto q = x / y, r = x % y;
	return q - ((r!=0) & ((r<0) ^ (y<0)));
}
template<typename T, typename U>
inline auto ceildiv(T x, U y) -> decltype(x/y) {
	auto q = x / y, r = x % y;
	return q + ((r!=0) & !((r<0) ^ (y<0)));
}

template<typename Func>
void line_sweep_convex(vector<P> ps, Func func) {
	vector<P> lower, upper;
	lower_and_upper_convex_hull(ps, lower, upper);
	int L = lower.size(), U = upper.size();
	P::T minx = lower[0].x, maxx = lower[L-1].x;
	int l = 0, u = 0;
	for(P::T x = minx; x < maxx; x ++) {
		if(lower[l+1].x <= x) l ++;
		if(upper[u+1].x <= x) u ++;
		//invariant: lower[l].x <= x < lower[l+1].x, upper[u].x <= x < upper[u+1].x
		//lower[l].y + (lower[l+1].y - lower[l].y) / (lower[l+1].x - lower[l].x) * (x - lower[l].x) <= y <= ...
		P::T xl = lower[l+1].x - lower[l].x, xu = upper[u+1].x - upper[u].x;
		P::T yl = lower[l+1].y - lower[l].y, yu = upper[u+1].y - upper[u].y;
		P::T miny = lower[l].y + ceildiv( (P::T2)yl * (x - lower[l].x), xl);
		P::T maxy = upper[u].y + floordiv((P::T2)yu * (x - upper[u].x), xu);
		func(x, miny, maxy);
	}
	func(maxx, lower[L-1].y, upper[U-1].y);
}
struct NCool {
	int nCoolPoints(vector <int> xs, vector <int> ys, int n) {
		vector<P> ps;
		rep(i, xs.size()) ps.pb(P(xs[i], ys[i]));
		map<pii, int> m;
		line_sweep_convex(ps, [n,&m](int x, int miny, int maxy) -> void {
			for(int y = miny; y <= maxy; y ++) {
				m[mp(x % (n-1), y % (n-1))] ++;
			}
		});
		int r = 0;
		each(i, m) if(i->second > 1) r += i->second;
		return r;
	}
};