# ============================================
# Digit Recognizer - CNN
# UMAP-EDA に基づく クラス別 Data Augmentation + Focal Loss + 5-fold Ensemble
# スコア: 0.99582 付近
# ============================================

import os
import random
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import StratifiedKFold

# --------------------------------------------
# 1. 再現性のための seed 固定
# --------------------------------------------
SEED = 42

def seed_everything(seed=SEED):
    """乱数の種をそろえて、毎回同じ結果が出やすいようにする。"""
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

seed_everything()

# --------------------------------------------
# 2. データ読み込み（Kaggle の Digit Recognizer 入力パス）
# --------------------------------------------
train = pd.read_csv("/kaggle/input/digit-recognizer/train.csv")
test  = pd.read_csv("/kaggle/input/digit-recognizer/test.csv")

# pixel0, pixel1, ... pixel783 のカラムだけを特徴量として使う
pixel_cols = [c for c in train.columns if c.startswith("pixel")]

# 0〜255 の画素値を 0〜1 に正規化（float32）
X = train[pixel_cols].values.astype("float32") / 255.0
y = train["label"].values.astype("int32")
X_test = test[pixel_cols].values.astype("float32") / 255.0

# CNN 用に (N, 28, 28, 1) に reshape （灰色画像1チャネル）
X = X.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)

NUM_CLASSES = 10

# --------------------------------------------
# 3. Focal Loss
#    → 簡単なサンプルの影響を弱め、間違えやすいサンプルを強調する損失
# --------------------------------------------
def focal_loss(gamma=2.0, alpha=0.25):
    """
    y_true: (batch,) 形式のクラスラベル（0〜9）
    y_pred: softmax 出力 (batch, num_classes)

    ・普通の categorical_crossentropy に
      (1 - p_t)^gamma という重みをかけることで、
      すでに正しく分類できているサンプル（p_t が大きい）を軽く、
      間違えがちなサンプル（p_t が小さい）を重く扱う。
    """
    def loss(y_true, y_pred):
        # one-hot に変換
        y_true_onehot = tf.one_hot(tf.cast(y_true, tf.int32), depth=NUM_CLASSES)
        # 通常のクロスエントロピー
        ce = keras.losses.categorical_crossentropy(y_true_onehot, y_pred)
        # p_t = 正解クラスに対応する確率
        p_t = tf.reduce_sum(y_true_onehot * y_pred, axis=-1)
        # Focal の重み (1 - p_t)^gamma * alpha
        focal_factor = alpha * tf.pow(1.0 - p_t, gamma)
        return focal_factor * ce
    return loss

# --------------------------------------------
# 4. CNN モデル定義
#    → 3ブロックの Conv-BN-ReLU + MaxPool + Dropout
# --------------------------------------------
def build_model(input_shape=(28, 28, 1), num_classes=10):
    inputs = keras.Input(shape=input_shape)
    x = inputs

    # Conv ブロックを 3 回繰り返す
    for filters, dropout in [(32,0.25),(64,0.30),(128,0.40)]:
        # Conv → BN → ReLU を 2 回
        x = layers.Conv2D(filters, 3, padding="same", use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        x = layers.Conv2D(filters, 3, padding="same", use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        # 空間サイズ 1/2 にするプーリング
        x = layers.MaxPooling2D(2)(x)
        # 過学習を防ぐ Dropout
        x = layers.Dropout(dropout)(x)

    # 全結合部分
    x = layers.Flatten()(x)
    x = layers.Dense(256, use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Dropout(0.5)(x)

    outputs = layers.Dense(num_classes, activation="softmax")(x)

    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(1e-3),
        loss=focal_loss(gamma=2.0, alpha=0.25),
        metrics=["accuracy"],
    )
    return model

# --------------------------------------------
# 5. クラス別 Data Augmentation パラメータ
#    「各数字ごとに」どれくらい回転・平行移動・拡大縮小・せん断をかけるかを指定。
#
#  - 0, 2 : UMAP 上でわりと孤立していて簡単 → 弱めの Aug
#  - 1,4,7,9 : 互いにクラスタが近く、混同しやすい → 強めの shift / rotation / zoom
#  - 3,5,8 : 中央部分の形が似ており、3 を中心に混ざりやすい → 強め
#  - 6 : UMAP でばらつきが大きいクラス → shear も入れて形の変化に強くする
# --------------------------------------------
digit_aug_params = {
    # 0: ほぼ真ん中の丸い形 → ちょっとの回転/平行移動/拡大のみ
    0: dict(
        rotation_range=5,        # 最大 ±5° だけ回転
        width_shift_range=0.05,  # 横方向に最大 5% シフト
        height_shift_range=0.05, # 縦方向に最大 5% シフト
        zoom_range=0.05          # 拡大縮小を ±5% だけ
    ),

    # 2: 若干のバリエーションを許容したいので 0 より少し強め
    2: dict(
        rotation_range=8,
        width_shift_range=0.08,
        height_shift_range=0.08,
        zoom_range=0.08
    ),

    # 1,4,7,9: 書き方が似ていて UMAP でも近い
    # → しっかり回転/シフト/ズームさせて、
    #   境界の形を多めに学習させる
    1: dict(
        rotation_range=15,
        width_shift_range=0.15,
        height_shift_range=0.15,
        zoom_range=0.15
    ),
    4: dict(
        rotation_range=15,
        width_shift_range=0.15,
        height_shift_range=0.15,
        zoom_range=0.15
    ),
    7: dict(
        rotation_range=15,
        width_shift_range=0.15,
        height_shift_range=0.15,
        zoom_range=0.15
    ),
    9: dict(
        rotation_range=15,
        width_shift_range=0.15,
        height_shift_range=0.15,
        zoom_range=0.15
    ),

    # 3,5,8: 丸みの部分が似ており、
    #        UMAP 上でも重なりやすい
    # → こちらも強めの Aug で
    #   境界をはっきりさせる
    3: dict(
        rotation_range=15,
        width_shift_range=0.15,
        height_shift_range=0.15,
        zoom_range=0.15
    ),
    5: dict(
        rotation_range=15,
        width_shift_range=0.15,
        height_shift_range=0.15,
        zoom_range=0.15
    ),
    8: dict(
        rotation_range=15,
        width_shift_range=0.15,
        height_shift_range=0.15,
        zoom_range=0.15
    ),

    # 6: 形状が diffuse（ばらけている）クラス
    # → shift/zoom をさらに強め、せん断(shear)も入れて形の変化に頑強にする
    6: dict(
        rotation_range=20,
        width_shift_range=0.20,
        height_shift_range=0.20,
        shear_range=10,          # せん断変換（斜めにゆがめる）を最大 10°
        zoom_range=0.20
    ),
}

# 各 digit 用の ImageDataGenerator を作る
# → 同じ ImageDataGenerator ではなく「クラスごとに別の変形ルール」を使う
digit_datagen = {
    d: ImageDataGenerator(**params)
    for d, params in digit_aug_params.items()
}

# --------------------------------------------
# 6. クラスごとの重み（class_weight）
#    → UMAP でばらつきや境界の曖昧さを見て、
#       難しいクラス(6, 3,5,8...) に少しだけ重みを大きくする。
#    → Keras の generator では class_weight 引数が使えないので、
#       あとで sample_weight に変換して使う。
# --------------------------------------------
class_weight = {
    0: 1.00,  # 簡単
    1: 1.05,
    2: 1.00,  # 簡単
    3: 1.10,  # 3,5,8 は混同しやすい
    4: 1.05,
    5: 1.10,
    6: 1.20,  # 最もばらつきが大きい → 重みを高めに
    7: 1.05,
    8: 1.10,
    9: 1.05,
}

# --------------------------------------------
# 7. クラス別 augmentation を実行するジェネレータ
#    - X, y を受け取り、ミニバッチごとに:
#        ・各サンプルのラベルに応じた digit_datagen[label] で変形
#        ・ラベルごとに class_weight を sample_weight に変換
#      を行い、(x_batch, y_batch, sample_weight_batch) を返す。
# --------------------------------------------
def classwise_generator(X, y, batch_size):
    n = len(X)
    indices = np.arange(n)
    while True:
        np.random.shuffle(indices)
        for start in range(0, n, batch_size):
            end = min(start + batch_size, n)
            batch_ids = indices[start:end]

            # 空のバッチを用意
            batch_x = np.empty((len(batch_ids), 28, 28, 1), dtype=np.float32)
            batch_y = y[batch_ids]

            # ラベルごとにクラス重みを取り出して、サンプル重みに変換
            batch_sw = np.array(
                [class_weight[int(lbl)] for lbl in batch_y],
                dtype=np.float32
            )

            # 1 サンプルずつ、「その digit 専用の ImageDataGenerator」で変形する
            for j, idx in enumerate(batch_ids):
                img = X[idx]
                label = int(y[idx])
                gen = digit_datagen[label]      # その digit 専用の Aug 設定
                batch_x[j] = gen.random_transform(img)  # 1 枚だけランダム変形

            # Keras は (x, y, sample_weight) 形式で重み付き学習ができる
            yield batch_x, batch_y, batch_sw

# --------------------------------------------
# 8. 5-fold Stratified KFold で Cross Validation + Test Ensemble
# --------------------------------------------
N_FOLDS = 5
EPOCHS = 40
BATCH_SIZE = 128

skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)

# 各 fold の test 予測を平均するためのバッファ
test_preds = np.zeros((X_test.shape[0], NUM_CLASSES), dtype="float32")

for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), 1):
    print(f"===== FOLD {fold}/{N_FOLDS} =====")

    X_tr, X_val = X[train_idx], X[val_idx]
    y_tr, y_val = y[train_idx], y[val_idx]

    model = build_model()

    ckpt_path = f"best_model_fold{fold}_umap_classaug.weights.h5"

    # 各 fold ごとに validation accuracy が最高の時点の重みだけ保存
    ckpt = keras.callbacks.ModelCheckpoint(
        ckpt_path,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
        verbose=1,
    )

    # 早期終了（validation accuracy が 8 epoch 連続で改善しなければ打ち切り）
    es = keras.callbacks.EarlyStopping(
        monitor="val_accuracy",
        patience=8,
        restore_best_weights=False,  # ベスト重みは ckpt から読むので False でOK
        verbose=1,
    )

    # validation loss が良くならなくなったら学習率を 1/2 に下げる
    rlrop = keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=4,
        min_lr=1e-5,
        verbose=1,
    )

    # クラス別 augmentation + sample_weight 付きのジェネレータ
    train_gen = classwise_generator(X_tr, y_tr, batch_size=BATCH_SIZE)
    steps_per_epoch = len(X_tr) // BATCH_SIZE

    # fit のときは class_weight ではなく、
    # ジェネレータから渡される sample_weight を使用（Keras が自動で扱う）
    model.fit(
        train_gen,
        epochs=EPOCHS,
        steps_per_epoch=steps_per_epoch,
        validation_data=(X_val, y_val),  # validation では augmentation や重みは使わない
        callbacks=[ckpt, es, rlrop],
        verbose=2,
    )

    # 各foldでベストだった重みを読み込む
    model.load_weights(ckpt_path)

    # Test データに対して予測し、fold ごとに 1/5 ずつ足していく（単純平均 ensemble）
    test_preds += model.predict(X_test, batch_size=BATCH_SIZE, verbose=1) / N_FOLDS

# --------------------------------------------
# 9. Test 予測をラベルに変換して submission.csv を出力
# --------------------------------------------
final_test_labels = np.argmax(test_preds, axis=1)

submission = pd.DataFrame({
    "ImageId": np.arange(1, len(final_test_labels)+1),
    "Label": final_test_labels
})

submission.to_csv("submission_cnn_5fold_umap_classaug.csv", index=False)
