Few-Shot learning with Reptile

Few-Shot  learning with Reptile

Few-Shot Learning چیست؟

در بسیاری از سناریوهای واقعی، جمع‌آوری مقادیر بزرگ از داده‌های برچسب‌خورده برای آموزش مدل یادگیری ماشین زمان‌بر، گرانقیمت یا به سادگی غیرقابل انجام است. یادگیری با تعداد کمی از داده‌ها به هدف مقابله با این مشکل می‌پردازد با آموزش مدل‌های قادر به سریع تطبیق به وظایف جدید با تنها چند نمونه برچسب‌خورده. هدف از این کار استفاده از دانش پیشین کسب شده از وظایف مشابه برای تسهیل یادگیری در وظایف نامعلوم است.Few-shot learning با استفاده از الگوریتم Reptile یک روش یادگیری ماشین است.

معرفی Reptile:

Reptile الگوریتم یادگیری متا است که توسط الکس نیکول و جان شولمن در سال 2018 پیشنهاد شده است. این الگوریتم برای حل مسئله یادگیری با تعداد کم اثربخش طراحی شده است. Reptile روش بهینه‌سازی بر پایه گرادیان را دنبال می‌کند، جایی که یاد می‌گیرد پارامترهای اولیه مدل را به گونه‌ای بهینه کند که بتواند به سرعت به وظایف جدید تطبیق پیدا کند. ایده اصلی پشت Reptile تقلید فرآیند تکامل است که با به‌روزرسانی متوالی پارامترهای مدل به سمت رفتار متوسط یک مجموعه وظایف مرتبط است. الگوریتم Reptile توسط OpenAI توسعه داده شده است تا  meta-learning-model  بدون وابستگی به مدل انجام دهد. به طور خاص، این الگوریتم برای یادگیری سریع انجام وظایف جدید با حداقل آموزش (few-shot learning) طراحی شده است. این الگوریتم با استفاده از کاهش وزن‌ها به وسیلهٔ کاهش گرادیان تصادفی (Stochastic Gradient Descent) با استفاده از تفاوت بین وزن‌های آموزش دیده شده بر روی یک mini-batch  از داده‌های قبلاً دیده نشده(never-seen-before-Data) و وزن‌های مدل قبل از آموزش، در طول تعداد ثابتی از متا-تکرارها (meta-iterations) عمل می‌کند.  

چگونه Reptile کار می‌کند؟

بگذارید با توضیحات مرحله به مرحله، به اصل کار Reptile بپردازیم: 1. شروع: مدل با پارامترهای تصادفی مقداردهی اولیه می‌شود. 2. نمونه‌برداری وظایف: زیرمجموعه‌ای از وظایف به صورت تصادفی از مجموعه آموزش انتخاب می‌شود. 3. حلقه داخلی: برای هر وظیفه، مدل با استفاده از تعداد کمی نمونه برچسب‌خورده از آن وظیفه آموزش داده می‌شود. پارامترهای مدل با استفاده از نزول گرادیان به‌روزرسانی می‌شود. 4. حلقه خارجی: پس از آموزش بر روی هر وظیفه، پارامترهای مدل به سمت متوسط پارامترهای به‌روز شده از حلقه داخلی به‌روزرسانی می‌شود. این گام به مدل کمک می‌کند تا ويژگي‌های مشترک را در وظایف فراگیر ضبط کند. 5. تکرار: گام‌های 3 و 4 برای تعداد ثابت دفعات تکرار می‌شود. 6. تطبيق: پس از پایان آموزش، با استفاده از چند گام گراديان با استفاده از پارامترهای به روز شده از حلقه خارجي، مدل به سرعت به وظایف جدید تطبيق پيدا مي‌کند. الگوریتم Reptile ابتدا یک مدل اولیه را با استفاده از داده‌های آموزش کمی آموزش می‌دهد. سپس برای هر وظیفه جدید، مدل را با استفاده از داده‌های آموزش کمی برای آن وظیفه به روزرسانی می‌دهد. این به روزرسانی‌ها به صورت تکراری انجام می‌شود و هر برخورد جدید با داده‌ها باعث بهبود عملکرد مدل برای وظیفه جدید می‌شود. Reptile از یک فرآیند به روزرسانی ساده و کارآمد برای تطبیق مدل با وظایف جدید استفاده می‌کند. این الگوریتم به عنوان یک الگوریتم یادگیری تقویتی (RL) شناخته شده است و در موارد کاربردی مختلف، اثبات شده است که عملکرد خوبی دارد، به ویژه در مسائل few-shot learning.  

Reptile چندین مزيت براي يادگيري با تعداد كمي دارد:

1. كارآيي: Reptile قابليت سريع تطبيق به وظايف جديد را با استفاده از تنها چندين گام گراديان فراهم مي كند. اين ويژگي آن را مناسب براي سناريوهايي كه نياز به تطبيق در زمان واقعي دارند، قابل قبول مي كند. 2. سادگي: الگوريتم نسبتاً ساده براي پياده‌سازي است و نياز به تغييرات ساختاري پيچيده در مدل ندارد. 3. عموميت: Reptile با استفاده از متوسط به روز شده پارامترها در وظايف، نشان داد كه نكات شائع را ياد بگيرد كه عملكردي بهتر را در وظايف نامعلوم فراهم مي كند. با استفاده از الگوریتم Reptile، می‌توان به سرعت و با دقت، مدل را برای وظایف جدید آموزش داد و از آن در مسائل few-shot learning به خوبی بهره برد.

     بررسی در کد:

1. وارد کردن کتابخانه‌ها:

ما کتابخانه‌های مورد نیاز را وارد می‌کنیم، numpy برای محاسبات عددی و TensorFlow برای ساخت و آموزش شبکه عصبی.

import numpy as np

import tensorflow as tf

2. تعریف معماری مدل:

در اینجا، ما معماری مدل شبکه عصبی را با استفاده از API سریالی Keras تعریف می‌کنیم. مدل شامل دو لایه پنهان با 64 واحد هر کدام است که از تابع فعال‌سازی ReLU استفاده می‌کنند. شکل ورودی به عنوان input_dim مشخص شده است و لایه خروجی شامل output_dim واحد با تابع فعال‌سازی softmax است.
model = tf.keras.models.Sequential([

tf.keras.layers.Dense(64, activation='relu', input_shape=(input_dim,)),

tf.keras.layers.Dense(64, activation='relu'),

tf.keras.layers.Dense(output_dim, activation='softmax')

])

3. تعریف تابع خطا و بهینه‌ساز:

ما تابع خطا را به عنوان categorical cross-entropy تعریف می‌کنیم که برای مسائل چند دسته‌ای معمولاً استفاده می‌شود. بهینه‌ساز را به SGD تنظیم می‌کنیم با نرخ یادگیری (Lr) 0.01.

loss_fn = tf.keras.losses.CategoricalCrossentropy()

optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

4.تنظیم تکرارهای حلقه داخلی و خارجی:

ما تعداد تکرارهای حلقه داخلی و خارجی را مشخص می‌کنیم. حلقه داخلی فرآیند آموزش مدل بر روی چند نمونه برچسب‌دار از هر وظیفه است، درحالی که حلقه خارجی وزن‌های اصلی مدل را         براساس میانگین به‌روزرسانی‌های حلقه داخلی به‌روز می‌کند.

num_inner_iterations = 10

num_outer_iterations = 100

5.حلقه آموزش Reptile:

حلقه آموزش Reptile شامل حلقه خارجی و داخلی است. درحلقه خارجی، ما برای تعداد تکرارهای ثابتی حلقه انجام می‌دهیم. در هر تکرار، ما به طور تصادفی یک زیرمجموعه از وظایف را از میان وظایف موجود (توسط num_inner_iterations) انتخاب می‌کنیم. ما یک کپی از مدل اصلی (model_copy) می‌سازیم و وزن‌های آن را با وزن‌های مدل اصلی مقداردهی اولیه می‌کنیم.

در حلقه داخلی، بر روی وظایف انتخاب شده حلقه می‌زنیم. برای هر وظیفه، ما چند نمونه برچسب‌دار را با استفاده از تابع sample_few_shot_examples نمونه‌برداری می‌کنیم              (x_train_task و y_train_task)، سپس با گذراندن این نمونه‌ها از طریق مدل کپی (model_copy) لاگیت‌ها (پیش‌بینی‌های غیرنرمال شده) را محاسبه می‌کنیم. ما با استفاده از loss_fn خطا را محاسبه کرده و با استفاده از یک نوار گرادیان گرفتن، گرادیان خطا نسبت به متغیرهای قابل آموزش مدل را محاسبه می‌کنیم.

for iteration in range(num_outer_iterations):

    tasks = np.random.choice(num_tasks, num_inner_iterations, replace=False)

    model_copy = tf.keras.models.clone_model(model)

    model_copy.set_weights(model.get_weights())

    
    for task in tasks:

        x_train_task, y_train_task = sample_few_shot_examples(task, num_examples_per_task)

        
        with tf.GradientTape() as tape:

            logits = model_copy(x_train_task)

            loss_value = loss_fn(y_train_task, logits)

        gradients = tape.gradient(loss_value, model_copy.trainable_variables) 
        optimizer.apply_gradients(zip(gradients, model_copy.trainable_variables))


    main_model_weights = model.get_weights()

    model_copy_weights = model_copy.get_weights()

    updated_weights = []

    for i in range(len(main_model_weights)):

        updated_weights.append(main_model_weights[i] + (model_copy_weights[i] - main_model_weights[i])) 
    model.set_weights(updated_weights)

6.یادگیری Few-Shot بر روی وظایف جدید:

پس از حلقه آموزش، ما می‌توانیم از مدل انطباق‌یافته برای یادگیری Few-Shot بر روی وظایف جدید استفاده کنیم. ما با استفاده از تابع sample_new_task یک وظیفه جدید را نمونه‌برداری می‌کنیم و تعدادی نمونه برچسب‌دار (x_test_task و y_test_task) را برای آزمایش از آن کار با استفاده از تابع sample_few_shot_example به دست می‌آوریم. در نهایت، مثال‌ها را از مدل عبور می‌دهیم تا پیش‌بینی‌هایی برای کار جدید به دست آوریم.

new_task = sample_new_task()

x_test_task, y_test_task = sample_few_shot_examples(new_task, num_test_examples_per_task)

predictions = model(x_test_task)
خروجی مدل برای داده‌های تست وظیفه جدید (predictions) به دست می‌آید. این خروجی شامل پیش‌بینی‌های مدل برای داده‌های تست است و معمولاً یک آرایه از احتمالات یا برچسب‌های    پیش‌بینی شده برای هر نمونه داده است.  

سوالات:

سوال 1: چه کسانی الگوریتم Reptile را پیشنهاد داده‌اند؟ الف) جان شولمن ب) الکس نیکول ج) جان شولمن و الکس نیکول د) الگوریتم Reptile توسط هیچ کدام از این افراد پیشنهاد نشده است. سوال 2: Reptile چه نوع الگوریتمی است؟ الف) الگوریتم یادگیری عمیق ب) الگوریتم یادگیری متا ج) الگوریتم یادگیری تقویتی د) الگوریتم یادگیری تصادفی سوال 3: چه وظایفی برای مدل‌های یادگیری ماشین مناسب استفاده از الگوریتم Reptile است؟ الف) وظایف با تعداد کم داده‌های برچسب‌خورده ب) وظایف با تعداد زیاد داده‌های برچسب‌خورده ج) وظایف با داده‌های برچسب‌خورده غیرقابل دسترس د) هیچ کدام از موارد فوق سوال 4: چه مزایایی برای یادگیری با تعداد کم داده‌ها به کمک الگوریتم Reptile وجود دارد؟ الف) کارآیی، سادگی، عمومیت ب) سرعت، دقت، عمومیت ج) کارآیی، سرعت، سادگی د) سرعت، دقت، سادگی سوال 5: چه مراحل اصلی در کارکرد الگوریتم Reptile وجود دارد؟ الف) شروع، نمونه‌برداری وظایف، حلقه داخلی، حلقه خارجی، تکرار، تطبیق ب) شروع، نمونه‌برداری وظایف، حلقه داخلی، حلقه خارجی، تکرار، به‌روزرسانی ج) شروع، نمونه‌برداری وظایف، حلقه داخلی، حلقه خارجی، تطبیق، به‌روزرسانی د) شروع، نمونه‌برداری وظایف، حلقه داخلی، حلقه خارجی،تطبیق، تکرار